Almost complete KDB

This commit is contained in:
Ricardo Montañana Gómez 2023-07-13 03:44:33 +02:00
parent 8b0aa5ccfb
commit 786d781e29
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
4 changed files with 51 additions and 7 deletions

View File

@ -12,7 +12,6 @@ namespace bayesnet {
this->features = features;
this->className = className;
this->states = states;
cout << "Checking fit parameters" << endl;
checkFitParameters();
train();
return *this;

View File

@ -4,6 +4,14 @@
namespace bayesnet {
using namespace std;
using namespace torch;
vector<int> argsort(vector<float>& nums)
{
int n = nums.size();
vector<int> indices(n);
iota(indices.begin(), indices.end(), 0);
sort(indices.begin(), indices.end(), [&nums](int i, int j) {return nums[i] > nums[j];});
return indices;
}
KDB::KDB(int k) : BaseClassifier(Network()), k(k) {}
void KDB::train()
{
@ -31,14 +39,45 @@ namespace bayesnet {
cout << "Computing mutual information between features and class" << endl;
auto n_classes = states[className].size();
auto metrics = Metrics(dataset, features, className, n_classes);
vector <float> mi;
for (auto i = 0; i < features.size(); i++) {
Tensor firstFeature = X.index({ "...", i });
Tensor secondFeature = y;
double mi = metrics.mutualInformation(firstFeature, y);
cout << "Mutual information between " << features[i] << " and " << className << " is " << mi << endl;
mi.push_back(metrics.mutualInformation(firstFeature, y));
cout << "Mutual information between " << features[i] << " and " << className << " is " << mi[i] << endl;
}
// 2. Compute class conditional mutual information I(Xi;XjIC), f or each
auto conditionalEdgeWeights = metrics.conditionalEdge();
cout << "Conditional edge weights" << endl;
cout << conditionalEdgeWeights << endl;
// 3. Let the used variable list, S, be empty.
vector<int> S;
// 4. Let the DAG network being constructed, BN, begin with a single
// class node, C.
model.addNode(className, states[className].size());
cout << "Adding node " << className << " to the network" << endl;
// 5. Repeat until S includes all domain features
// 5.1. Select feature Xmax which is not in S and has the largest value
// I(Xmax;C).
auto order = argsort(mi);
for (auto idx : order) {
cout << idx << " " << mi[idx] << endl;
// 5.2. Add a node to BN representing Xmax.
model.addNode(features[idx], states[features[idx]].size());
// 5.3. Add an arc from C to Xmax in BN.
model.addEdge(className, features[idx]);
// 5.4. Add m = min(lSl,/c) arcs from m distinct features Xj in S with
// the highest value for I(Xmax;X,jC).
// auto conditionalEdgeWeightsAccessor = conditionalEdgeWeights.accessor<float, 2>();
// auto conditionalEdgeWeightsSorted = conditionalEdgeWeightsAccessor[idx].sort();
// auto conditionalEdgeWeightsSortedAccessor = conditionalEdgeWeightsSorted.accessor<float, 1>();
// for (auto i = 0; i < k; ++i) {
// auto index = conditionalEdgeWeightsSortedAccessor[i].item<int>();
// model.addEdge(features[idx], features[index]);
// }
// 5.5. Add Xmax to S.
S.push_back(idx);
}
}
}
}

View File

@ -30,7 +30,7 @@ namespace bayesnet {
}
return result;
}
vector<float> Metrics::conditionalEdgeWeights()
torch::Tensor Metrics::conditionalEdge()
{
auto result = vector<double>();
auto source = vector<string>(features);
@ -65,6 +65,11 @@ namespace bayesnet {
matrix[x][y] = result[i];
matrix[y][x] = result[i];
}
return matrix;
}
vector<float> Metrics::conditionalEdgeWeights()
{
auto matrix = conditionalEdge();
std::vector<float> v(matrix.data_ptr<float>(), matrix.data_ptr<float>() + matrix.numel());
return v;
}

View File

@ -19,6 +19,7 @@ namespace bayesnet {
Metrics(torch::Tensor&, vector<string>&, string&, int);
Metrics(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&, const int);
vector<float> conditionalEdgeWeights();
torch::Tensor conditionalEdge();
};
}
#endif