Almost complete KDB
This commit is contained in:
parent
8b0aa5ccfb
commit
786d781e29
@ -12,7 +12,6 @@ namespace bayesnet {
|
|||||||
this->features = features;
|
this->features = features;
|
||||||
this->className = className;
|
this->className = className;
|
||||||
this->states = states;
|
this->states = states;
|
||||||
cout << "Checking fit parameters" << endl;
|
|
||||||
checkFitParameters();
|
checkFitParameters();
|
||||||
train();
|
train();
|
||||||
return *this;
|
return *this;
|
||||||
|
49
src/KDB.cc
49
src/KDB.cc
@ -4,6 +4,14 @@
|
|||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace torch;
|
using namespace torch;
|
||||||
|
vector<int> argsort(vector<float>& nums)
|
||||||
|
{
|
||||||
|
int n = nums.size();
|
||||||
|
vector<int> indices(n);
|
||||||
|
iota(indices.begin(), indices.end(), 0);
|
||||||
|
sort(indices.begin(), indices.end(), [&nums](int i, int j) {return nums[i] > nums[j];});
|
||||||
|
return indices;
|
||||||
|
}
|
||||||
KDB::KDB(int k) : BaseClassifier(Network()), k(k) {}
|
KDB::KDB(int k) : BaseClassifier(Network()), k(k) {}
|
||||||
void KDB::train()
|
void KDB::train()
|
||||||
{
|
{
|
||||||
@ -31,14 +39,45 @@ namespace bayesnet {
|
|||||||
cout << "Computing mutual information between features and class" << endl;
|
cout << "Computing mutual information between features and class" << endl;
|
||||||
auto n_classes = states[className].size();
|
auto n_classes = states[className].size();
|
||||||
auto metrics = Metrics(dataset, features, className, n_classes);
|
auto metrics = Metrics(dataset, features, className, n_classes);
|
||||||
|
vector <float> mi;
|
||||||
for (auto i = 0; i < features.size(); i++) {
|
for (auto i = 0; i < features.size(); i++) {
|
||||||
Tensor firstFeature = X.index({ "...", i });
|
Tensor firstFeature = X.index({ "...", i });
|
||||||
Tensor secondFeature = y;
|
mi.push_back(metrics.mutualInformation(firstFeature, y));
|
||||||
double mi = metrics.mutualInformation(firstFeature, y);
|
cout << "Mutual information between " << features[i] << " and " << className << " is " << mi[i] << endl;
|
||||||
cout << "Mutual information between " << features[i] << " and " << className << " is " << mi << endl;
|
}
|
||||||
|
// 2. Compute class conditional mutual information I(Xi;XjIC), f or each
|
||||||
|
auto conditionalEdgeWeights = metrics.conditionalEdge();
|
||||||
|
cout << "Conditional edge weights" << endl;
|
||||||
|
cout << conditionalEdgeWeights << endl;
|
||||||
|
// 3. Let the used variable list, S, be empty.
|
||||||
|
vector<int> S;
|
||||||
|
// 4. Let the DAG network being constructed, BN, begin with a single
|
||||||
|
// class node, C.
|
||||||
|
model.addNode(className, states[className].size());
|
||||||
|
cout << "Adding node " << className << " to the network" << endl;
|
||||||
|
// 5. Repeat until S includes all domain features
|
||||||
|
// 5.1. Select feature Xmax which is not in S and has the largest value
|
||||||
|
// I(Xmax;C).
|
||||||
|
auto order = argsort(mi);
|
||||||
|
for (auto idx : order) {
|
||||||
|
cout << idx << " " << mi[idx] << endl;
|
||||||
|
// 5.2. Add a node to BN representing Xmax.
|
||||||
|
model.addNode(features[idx], states[features[idx]].size());
|
||||||
|
// 5.3. Add an arc from C to Xmax in BN.
|
||||||
|
model.addEdge(className, features[idx]);
|
||||||
|
// 5.4. Add m = min(lSl,/c) arcs from m distinct features Xj in S with
|
||||||
|
// the highest value for I(Xmax;X,jC).
|
||||||
|
// auto conditionalEdgeWeightsAccessor = conditionalEdgeWeights.accessor<float, 2>();
|
||||||
|
// auto conditionalEdgeWeightsSorted = conditionalEdgeWeightsAccessor[idx].sort();
|
||||||
|
// auto conditionalEdgeWeightsSortedAccessor = conditionalEdgeWeightsSorted.accessor<float, 1>();
|
||||||
|
// for (auto i = 0; i < k; ++i) {
|
||||||
|
// auto index = conditionalEdgeWeightsSortedAccessor[i].item<int>();
|
||||||
|
// model.addEdge(features[idx], features[index]);
|
||||||
|
// }
|
||||||
|
// 5.5. Add Xmax to S.
|
||||||
|
S.push_back(idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
@ -30,7 +30,7 @@ namespace bayesnet {
|
|||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
vector<float> Metrics::conditionalEdgeWeights()
|
torch::Tensor Metrics::conditionalEdge()
|
||||||
{
|
{
|
||||||
auto result = vector<double>();
|
auto result = vector<double>();
|
||||||
auto source = vector<string>(features);
|
auto source = vector<string>(features);
|
||||||
@ -65,6 +65,11 @@ namespace bayesnet {
|
|||||||
matrix[x][y] = result[i];
|
matrix[x][y] = result[i];
|
||||||
matrix[y][x] = result[i];
|
matrix[y][x] = result[i];
|
||||||
}
|
}
|
||||||
|
return matrix;
|
||||||
|
}
|
||||||
|
vector<float> Metrics::conditionalEdgeWeights()
|
||||||
|
{
|
||||||
|
auto matrix = conditionalEdge();
|
||||||
std::vector<float> v(matrix.data_ptr<float>(), matrix.data_ptr<float>() + matrix.numel());
|
std::vector<float> v(matrix.data_ptr<float>(), matrix.data_ptr<float>() + matrix.numel());
|
||||||
return v;
|
return v;
|
||||||
}
|
}
|
||||||
|
@ -19,6 +19,7 @@ namespace bayesnet {
|
|||||||
Metrics(torch::Tensor&, vector<string>&, string&, int);
|
Metrics(torch::Tensor&, vector<string>&, string&, int);
|
||||||
Metrics(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&, const int);
|
Metrics(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&, const int);
|
||||||
vector<float> conditionalEdgeWeights();
|
vector<float> conditionalEdgeWeights();
|
||||||
|
torch::Tensor conditionalEdge();
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
Loading…
Reference in New Issue
Block a user