From 3f09d474f9360494683775f4d28d753ba1b54706 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana?= Date: Fri, 14 Jul 2023 00:10:55 +0200 Subject: [PATCH] Refactor BaseClassifier and begin TAN impl. --- src/BaseClassifier.cc | 24 +++++++++++++++++++++--- src/BaseClassifier.h | 9 +++++---- src/CMakeLists.txt | 2 +- src/KDB.cc | 33 +-------------------------------- src/KDB.h | 1 - src/Metrics.cc | 8 ++++++++ src/Metrics.hpp | 19 +++++++++++-------- src/TAN.cc | 25 +++++++++++++++++++++++++ src/TAN.h | 15 +++++++++++++++ 9 files changed, 87 insertions(+), 49 deletions(-) create mode 100644 src/TAN.cc create mode 100644 src/TAN.h diff --git a/src/BaseClassifier.cc b/src/BaseClassifier.cc index 57d02d3..5f7b161 100644 --- a/src/BaseClassifier.cc +++ b/src/BaseClassifier.cc @@ -4,7 +4,7 @@ namespace bayesnet { using namespace std; using namespace torch; - BaseClassifier::BaseClassifier(Network model) : model(model), m(0), n(0) {} + BaseClassifier::BaseClassifier(Network model) : model(model), m(0), n(0), metrics(Metrics()) {} BaseClassifier& BaseClassifier::build(vector& features, string className, map>& states) { @@ -13,6 +13,8 @@ namespace bayesnet { this->className = className; this->states = states; checkFitParameters(); + auto n_classes = states[className].size(); + metrics = Metrics(dataset, features, className, n_classes); train(); return *this; } @@ -51,6 +53,14 @@ namespace bayesnet { } } } + vector BaseClassifier::argsort(vector& nums) + { + int n = nums.size(); + vector indices(n); + iota(indices.begin(), indices.end(), 0); + sort(indices.begin(), indices.end(), [&nums](int i, int j) {return nums[i] > nums[j];}); + return indices; + } vector> tensorToVector(const torch::Tensor& tensor) { // convert mxn tensor to nxm vector @@ -86,8 +96,16 @@ namespace bayesnet { Tensor y_pred = predict(X); return (y_pred == y).sum().item() / y.size(0); } - void BaseClassifier::show() + vector BaseClassifier::show() { - model.show(); + return model.show(); + } + void BaseClassifier::addNodes() + { + // Add all nodes to the network + for (auto feature : features) { + model.addNode(feature, states[feature].size()); + } + model.addNode(className, states[className].size()); } } \ No newline at end of file diff --git a/src/BaseClassifier.h b/src/BaseClassifier.h index 15395d3..aca3066 100644 --- a/src/BaseClassifier.h +++ b/src/BaseClassifier.h @@ -1,6 +1,7 @@ #ifndef CLASSIFIERS_H #include #include "Network.h" +#include "Metrics.hpp" using namespace std; using namespace torch; @@ -14,6 +15,7 @@ namespace bayesnet { Tensor X; Tensor y; Tensor dataset; + Metrics metrics; vector features; string className; map> states; @@ -21,14 +23,13 @@ namespace bayesnet { virtual void train() = 0; public: BaseClassifier(Network model); - Tensor& getX(); - vector& getFeatures(); - string& getClassName(); BaseClassifier& fit(Tensor& X, Tensor& y, vector& features, string className, map>& states); BaseClassifier& fit(vector>& X, vector& y, vector& features, string className, map>& states); + void addNodes(); Tensor predict(Tensor& X); float score(Tensor& X, Tensor& y); - void show(); + vector show(); + vector argsort(vector& nums); }; } #endif diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index d6caf2e..b76c503 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,2 +1,2 @@ -add_library(BayesNet Network.cc Node.cc Metrics.cc BaseClassifier.cc KDB.cc) +add_library(BayesNet Network.cc Node.cc Metrics.cc BaseClassifier.cc KDB.cc TAN.cc) target_link_libraries(BayesNet "${TORCH_LIBRARIES}") \ No newline at end of file diff --git a/src/KDB.cc b/src/KDB.cc index 5a8f11f..d28d805 100644 --- a/src/KDB.cc +++ b/src/KDB.cc @@ -1,17 +1,9 @@ #include "KDB.h" -#include "Metrics.hpp" namespace bayesnet { using namespace std; using namespace torch; - vector argsort(vector& nums) - { - int n = nums.size(); - vector indices(n); - iota(indices.begin(), indices.end(), 0); - sort(indices.begin(), indices.end(), [&nums](int i, int j) {return nums[i] > nums[j];}); - return indices; - } + KDB::KDB(int k, float theta) : BaseClassifier(Network()), k(k), theta(theta) {} void KDB::train() { @@ -36,31 +28,23 @@ namespace bayesnet { */ // 1. For each feature Xi, compute mutual information, I(X;C), // where C is the class. - cout << "Computing mutual information between features and class" << endl; - auto n_classes = states[className].size(); - auto metrics = Metrics(dataset, features, className, n_classes); vector mi; for (auto i = 0; i < features.size(); i++) { Tensor firstFeature = X.index({ "...", i }); mi.push_back(metrics.mutualInformation(firstFeature, y)); - cout << "Mutual information between " << features[i] << " and " << className << " is " << mi[i] << endl; } // 2. Compute class conditional mutual information I(Xi;XjIC), f or each auto conditionalEdgeWeights = metrics.conditionalEdge(); - cout << "Conditional edge weights" << endl; - cout << conditionalEdgeWeights << endl; // 3. Let the used variable list, S, be empty. vector S; // 4. Let the DAG network being constructed, BN, begin with a single // class node, C. model.addNode(className, states[className].size()); - cout << "Adding node " << className << " to the network" << endl; // 5. Repeat until S includes all domain features // 5.1. Select feature Xmax which is not in S and has the largest value // I(Xmax;C). auto order = argsort(mi); for (auto idx : order) { - cout << idx << " " << mi[idx] << endl; // 5.2. Add a node to BN representing Xmax. model.addNode(features[idx], states[features[idx]].size()); // 5.3. Add an arc from C to Xmax in BN. @@ -76,8 +60,6 @@ namespace bayesnet { { auto n_edges = min(k, static_cast(S.size())); auto cond_w = clone(weights); - cout << "Conditional edge weights cloned for idx " << idx << endl; - cout << cond_w << endl; bool exit_cond = k == 0; int num = 0; while (!exit_cond) { @@ -93,22 +75,9 @@ namespace bayesnet { } } cond_w.index_put_({ idx, max_minfo }, -1); - cout << "Conditional edge weights cloned for idx " << idx << " After -1" << endl; - cout << cond_w << endl; - cout << "cond_w.index({ idx, '...'})" << endl; - cout << cond_w.index({ idx, "..." }) << endl; auto candidates_mask = cond_w.index({ idx, "..." }).gt(theta); auto candidates = candidates_mask.nonzero(); - cout << "Candidates mask" << endl; - cout << candidates_mask << endl; - cout << "Candidates: " << endl; - cout << candidates << endl; - cout << "Candidates size: " << candidates.size(0) << endl; exit_cond = num == n_edges || candidates.size(0) == 0; } } - vector KDB::show() - { - return model.show(); - } } \ No newline at end of file diff --git a/src/KDB.h b/src/KDB.h index 930a125..a8d154c 100644 --- a/src/KDB.h +++ b/src/KDB.h @@ -13,7 +13,6 @@ namespace bayesnet { void train() override; public: KDB(int k, float theta = 0.03); - vector show(); }; } #endif \ No newline at end of file diff --git a/src/Metrics.cc b/src/Metrics.cc index f8174c1..b6b5d8b 100644 --- a/src/Metrics.cc +++ b/src/Metrics.cc @@ -116,4 +116,12 @@ namespace bayesnet { { return entropy(firstFeature) - conditionalEntropy(firstFeature, secondFeature); } + vector> Metrics::maximumSpanningTree(Tensor& weights) + { + auto result = vector>(); + // Compute the maximum spanning tree considering the weights as distances + // and the indices of the weights as nodes of this square matrix + + return result; + } } \ No newline at end of file diff --git a/src/Metrics.hpp b/src/Metrics.hpp index f939e44..2934f83 100644 --- a/src/Metrics.hpp +++ b/src/Metrics.hpp @@ -3,23 +3,26 @@ #include #include #include -using namespace std; namespace bayesnet { + using namespace std; + using namespace torch; class Metrics { private: - torch::Tensor samples; + Tensor samples; vector features; string className; int classNumStates; - vector> doCombinations(const vector&); - double entropy(torch::Tensor&); - double conditionalEntropy(torch::Tensor&, torch::Tensor&); public: - double mutualInformation(torch::Tensor&, torch::Tensor&); - Metrics(torch::Tensor&, vector&, string&, int); + Metrics() = default; + Metrics(Tensor&, vector&, string&, int); Metrics(const vector>&, const vector&, const vector&, const string&, const int); + double entropy(Tensor&); + double conditionalEntropy(Tensor&, Tensor&); + double mutualInformation(Tensor&, Tensor&); vector conditionalEdgeWeights(); - torch::Tensor conditionalEdge(); + Tensor conditionalEdge(); + vector> doCombinations(const vector&); + vector> maximumSpanningTree(Tensor& weights); }; } #endif \ No newline at end of file diff --git a/src/TAN.cc b/src/TAN.cc new file mode 100644 index 0000000..d490a23 --- /dev/null +++ b/src/TAN.cc @@ -0,0 +1,25 @@ +#include "TAN.h" + +namespace bayesnet { + using namespace std; + using namespace torch; + + TAN::TAN() : BaseClassifier(Network()) {} + + void TAN::train() + { + // 0. Add all nodes to the model + addNodes(); + // 1. Compute mutual information between each feature and the class + auto weights = metrics.conditionalEdge(); + // 2. Compute the maximum spanning tree + auto mst = metrics.maximumSpanningTree(weights); + // 3. Add edges from the maximum spanning tree to the model + for (auto i = 0; i < mst.size(); ++i) { + auto [from, to] = mst[i]; + model.addEdge(features[from], features[to]); + } + + } + +} \ No newline at end of file diff --git a/src/TAN.h b/src/TAN.h new file mode 100644 index 0000000..d1477b6 --- /dev/null +++ b/src/TAN.h @@ -0,0 +1,15 @@ +#ifndef TAN_H +#define TAN_H +#include "BaseClassifier.h" +namespace bayesnet { + using namespace std; + using namespace torch; + class TAN : public BaseClassifier { + private: + protected: + void train() override; + public: + TAN(); + }; +} +#endif \ No newline at end of file