Almost complete KDB
This commit is contained in:
parent
8b0aa5ccfb
commit
786d781e29
@ -12,7 +12,6 @@ namespace bayesnet {
|
||||
this->features = features;
|
||||
this->className = className;
|
||||
this->states = states;
|
||||
cout << "Checking fit parameters" << endl;
|
||||
checkFitParameters();
|
||||
train();
|
||||
return *this;
|
||||
|
49
src/KDB.cc
49
src/KDB.cc
@ -4,6 +4,14 @@
|
||||
namespace bayesnet {
|
||||
using namespace std;
|
||||
using namespace torch;
|
||||
vector<int> argsort(vector<float>& nums)
|
||||
{
|
||||
int n = nums.size();
|
||||
vector<int> indices(n);
|
||||
iota(indices.begin(), indices.end(), 0);
|
||||
sort(indices.begin(), indices.end(), [&nums](int i, int j) {return nums[i] > nums[j];});
|
||||
return indices;
|
||||
}
|
||||
KDB::KDB(int k) : BaseClassifier(Network()), k(k) {}
|
||||
void KDB::train()
|
||||
{
|
||||
@ -31,14 +39,45 @@ namespace bayesnet {
|
||||
cout << "Computing mutual information between features and class" << endl;
|
||||
auto n_classes = states[className].size();
|
||||
auto metrics = Metrics(dataset, features, className, n_classes);
|
||||
vector <float> mi;
|
||||
for (auto i = 0; i < features.size(); i++) {
|
||||
Tensor firstFeature = X.index({ "...", i });
|
||||
Tensor secondFeature = y;
|
||||
double mi = metrics.mutualInformation(firstFeature, y);
|
||||
cout << "Mutual information between " << features[i] << " and " << className << " is " << mi << endl;
|
||||
mi.push_back(metrics.mutualInformation(firstFeature, y));
|
||||
cout << "Mutual information between " << features[i] << " and " << className << " is " << mi[i] << endl;
|
||||
}
|
||||
// 2. Compute class conditional mutual information I(Xi;XjIC), f or each
|
||||
auto conditionalEdgeWeights = metrics.conditionalEdge();
|
||||
cout << "Conditional edge weights" << endl;
|
||||
cout << conditionalEdgeWeights << endl;
|
||||
// 3. Let the used variable list, S, be empty.
|
||||
vector<int> S;
|
||||
// 4. Let the DAG network being constructed, BN, begin with a single
|
||||
// class node, C.
|
||||
model.addNode(className, states[className].size());
|
||||
cout << "Adding node " << className << " to the network" << endl;
|
||||
// 5. Repeat until S includes all domain features
|
||||
// 5.1. Select feature Xmax which is not in S and has the largest value
|
||||
// I(Xmax;C).
|
||||
auto order = argsort(mi);
|
||||
for (auto idx : order) {
|
||||
cout << idx << " " << mi[idx] << endl;
|
||||
// 5.2. Add a node to BN representing Xmax.
|
||||
model.addNode(features[idx], states[features[idx]].size());
|
||||
// 5.3. Add an arc from C to Xmax in BN.
|
||||
model.addEdge(className, features[idx]);
|
||||
// 5.4. Add m = min(lSl,/c) arcs from m distinct features Xj in S with
|
||||
// the highest value for I(Xmax;X,jC).
|
||||
// auto conditionalEdgeWeightsAccessor = conditionalEdgeWeights.accessor<float, 2>();
|
||||
// auto conditionalEdgeWeightsSorted = conditionalEdgeWeightsAccessor[idx].sort();
|
||||
// auto conditionalEdgeWeightsSortedAccessor = conditionalEdgeWeightsSorted.accessor<float, 1>();
|
||||
// for (auto i = 0; i < k; ++i) {
|
||||
// auto index = conditionalEdgeWeightsSortedAccessor[i].item<int>();
|
||||
// model.addEdge(features[idx], features[index]);
|
||||
// }
|
||||
// 5.5. Add Xmax to S.
|
||||
S.push_back(idx);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
@ -30,7 +30,7 @@ namespace bayesnet {
|
||||
}
|
||||
return result;
|
||||
}
|
||||
vector<float> Metrics::conditionalEdgeWeights()
|
||||
torch::Tensor Metrics::conditionalEdge()
|
||||
{
|
||||
auto result = vector<double>();
|
||||
auto source = vector<string>(features);
|
||||
@ -65,6 +65,11 @@ namespace bayesnet {
|
||||
matrix[x][y] = result[i];
|
||||
matrix[y][x] = result[i];
|
||||
}
|
||||
return matrix;
|
||||
}
|
||||
vector<float> Metrics::conditionalEdgeWeights()
|
||||
{
|
||||
auto matrix = conditionalEdge();
|
||||
std::vector<float> v(matrix.data_ptr<float>(), matrix.data_ptr<float>() + matrix.numel());
|
||||
return v;
|
||||
}
|
||||
|
@ -19,6 +19,7 @@ namespace bayesnet {
|
||||
Metrics(torch::Tensor&, vector<string>&, string&, int);
|
||||
Metrics(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&, const int);
|
||||
vector<float> conditionalEdgeWeights();
|
||||
torch::Tensor conditionalEdge();
|
||||
};
|
||||
}
|
||||
#endif
|
Loading…
Reference in New Issue
Block a user