From 786d781e292622e67676af7067ae5ee3ed216206 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Thu, 13 Jul 2023 03:44:33 +0200 Subject: [PATCH] Almost complete KDB --- src/BaseClassifier.cc | 1 - src/KDB.cc | 49 ++++++++++++++++++++++++++++++++++++++----- src/Metrics.cc | 7 ++++++- src/Metrics.hpp | 1 + 4 files changed, 51 insertions(+), 7 deletions(-) diff --git a/src/BaseClassifier.cc b/src/BaseClassifier.cc index 6b67cac..0673b0d 100644 --- a/src/BaseClassifier.cc +++ b/src/BaseClassifier.cc @@ -12,7 +12,6 @@ namespace bayesnet { this->features = features; this->className = className; this->states = states; - cout << "Checking fit parameters" << endl; checkFitParameters(); train(); return *this; diff --git a/src/KDB.cc b/src/KDB.cc index df6ad77..178f1f0 100644 --- a/src/KDB.cc +++ b/src/KDB.cc @@ -4,6 +4,14 @@ namespace bayesnet { using namespace std; using namespace torch; + vector argsort(vector& nums) + { + int n = nums.size(); + vector indices(n); + iota(indices.begin(), indices.end(), 0); + sort(indices.begin(), indices.end(), [&nums](int i, int j) {return nums[i] > nums[j];}); + return indices; + } KDB::KDB(int k) : BaseClassifier(Network()), k(k) {} void KDB::train() { @@ -31,14 +39,45 @@ namespace bayesnet { cout << "Computing mutual information between features and class" << endl; auto n_classes = states[className].size(); auto metrics = Metrics(dataset, features, className, n_classes); + vector mi; for (auto i = 0; i < features.size(); i++) { Tensor firstFeature = X.index({ "...", i }); - Tensor secondFeature = y; - double mi = metrics.mutualInformation(firstFeature, y); - cout << "Mutual information between " << features[i] << " and " << className << " is " << mi << endl; - + mi.push_back(metrics.mutualInformation(firstFeature, y)); + cout << "Mutual information between " << features[i] << " and " << className << " is " << mi[i] << endl; + } + // 2. 
Compute class conditional mutual information I(Xi;Xj|C), for each pair of features Xi and Xj + auto conditionalEdgeWeights = metrics.conditionalEdge(); + cout << "Conditional edge weights" << endl; + cout << conditionalEdgeWeights << endl; + // 3. Let the used variable list, S, be empty. + vector S; + // 4. Let the DAG network being constructed, BN, begin with a single + // class node, C. + model.addNode(className, states[className].size()); + cout << "Adding node " << className << " to the network" << endl; + // 5. Repeat until S includes all domain features + // 5.1. Select feature Xmax which is not in S and has the largest value + // I(Xmax;C). + auto order = argsort(mi); + for (auto idx : order) { + cout << idx << " " << mi[idx] << endl; + // 5.2. Add a node to BN representing Xmax. + model.addNode(features[idx], states[features[idx]].size()); + // 5.3. Add an arc from C to Xmax in BN. + model.addEdge(className, features[idx]); + // 5.4. Add m = min(|S|, k) arcs from m distinct features Xj in S with + // the highest value for I(Xmax;Xj|C). + // auto conditionalEdgeWeightsAccessor = conditionalEdgeWeights.accessor(); + // auto conditionalEdgeWeightsSorted = conditionalEdgeWeightsAccessor[idx].sort(); + // auto conditionalEdgeWeightsSortedAccessor = conditionalEdgeWeightsSorted.accessor(); + // for (auto i = 0; i < k; ++i) { + // auto index = conditionalEdgeWeightsSortedAccessor[i].item(); + // model.addEdge(features[idx], features[index]); + // } + // 5.5. Add Xmax to S. 
+ S.push_back(idx); } - } + } \ No newline at end of file diff --git a/src/Metrics.cc b/src/Metrics.cc index dacc0cb..f8174c1 100644 --- a/src/Metrics.cc +++ b/src/Metrics.cc @@ -30,7 +30,7 @@ namespace bayesnet { } return result; } - vector Metrics::conditionalEdgeWeights() + torch::Tensor Metrics::conditionalEdge() { auto result = vector(); auto source = vector(features); @@ -65,6 +65,11 @@ namespace bayesnet { matrix[x][y] = result[i]; matrix[y][x] = result[i]; } + return matrix; + } + vector Metrics::conditionalEdgeWeights() + { + auto matrix = conditionalEdge(); std::vector v(matrix.data_ptr(), matrix.data_ptr() + matrix.numel()); return v; } diff --git a/src/Metrics.hpp b/src/Metrics.hpp index 2754f0a..f939e44 100644 --- a/src/Metrics.hpp +++ b/src/Metrics.hpp @@ -19,6 +19,7 @@ namespace bayesnet { Metrics(torch::Tensor&, vector&, string&, int); Metrics(const vector>&, const vector&, const vector&, const string&, const int); vector conditionalEdgeWeights(); + torch::Tensor conditionalEdge(); }; } #endif \ No newline at end of file