Refactor BaseClassifier and begin TAN impl.

This commit is contained in:
Ricardo Montañana Gómez 2023-07-14 00:10:55 +02:00
parent e52fdc718f
commit 3f09d474f9
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
9 changed files with 87 additions and 49 deletions

View File

@ -4,7 +4,7 @@ namespace bayesnet {
using namespace std; using namespace std;
using namespace torch; using namespace torch;
BaseClassifier::BaseClassifier(Network model) : model(model), m(0), n(0) {} BaseClassifier::BaseClassifier(Network model) : model(model), m(0), n(0), metrics(Metrics()) {}
BaseClassifier& BaseClassifier::build(vector<string>& features, string className, map<string, vector<int>>& states) BaseClassifier& BaseClassifier::build(vector<string>& features, string className, map<string, vector<int>>& states)
{ {
@ -13,6 +13,8 @@ namespace bayesnet {
this->className = className; this->className = className;
this->states = states; this->states = states;
checkFitParameters(); checkFitParameters();
auto n_classes = states[className].size();
metrics = Metrics(dataset, features, className, n_classes);
train(); train();
return *this; return *this;
} }
@ -51,6 +53,14 @@ namespace bayesnet {
} }
} }
} }
vector<int> BaseClassifier::argsort(vector<float>& nums)
{
    // Return the indices that would sort nums in descending order
    // (nums itself is left untouched).
    vector<int> order(nums.size());
    iota(order.begin(), order.end(), 0);
    auto byValueDescending = [&nums](int a, int b) { return nums[a] > nums[b]; };
    sort(order.begin(), order.end(), byValueDescending);
    return order;
}
vector<vector<int>> tensorToVector(const torch::Tensor& tensor) vector<vector<int>> tensorToVector(const torch::Tensor& tensor)
{ {
// convert mxn tensor to nxm vector // convert mxn tensor to nxm vector
@ -86,8 +96,16 @@ namespace bayesnet {
Tensor y_pred = predict(X); Tensor y_pred = predict(X);
return (y_pred == y).sum().item<float>() / y.size(0); return (y_pred == y).sum().item<float>() / y.size(0);
} }
void BaseClassifier::show() vector<string> BaseClassifier::show()
{ {
model.show(); return model.show();
}
void BaseClassifier::addNodes()
{
    // Register every feature variable plus the class variable in the network,
    // sized by the number of discrete states each one can take.
    // const auto& avoids copying each feature name per iteration.
    for (const auto& feature : features) {
        model.addNode(feature, states[feature].size());
    }
    model.addNode(className, states[className].size());
}
} }

View File

@ -1,6 +1,7 @@
#ifndef CLASSIFIERS_H #ifndef CLASSIFIERS_H
#include <torch/torch.h> #include <torch/torch.h>
#include "Network.h" #include "Network.h"
#include "Metrics.hpp"
using namespace std; using namespace std;
using namespace torch; using namespace torch;
@ -14,6 +15,7 @@ namespace bayesnet {
Tensor X; Tensor X;
Tensor y; Tensor y;
Tensor dataset; Tensor dataset;
Metrics metrics;
vector<string> features; vector<string> features;
string className; string className;
map<string, vector<int>> states; map<string, vector<int>> states;
@ -21,14 +23,13 @@ namespace bayesnet {
virtual void train() = 0; virtual void train() = 0;
public: public:
BaseClassifier(Network model); BaseClassifier(Network model);
Tensor& getX();
vector<string>& getFeatures();
string& getClassName();
BaseClassifier& fit(Tensor& X, Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states); BaseClassifier& fit(Tensor& X, Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states);
BaseClassifier& fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states); BaseClassifier& fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states);
void addNodes();
Tensor predict(Tensor& X); Tensor predict(Tensor& X);
float score(Tensor& X, Tensor& y); float score(Tensor& X, Tensor& y);
void show(); vector<string> show();
vector<int> argsort(vector<float>& nums);
}; };
} }
#endif #endif

View File

@ -1,2 +1,2 @@
add_library(BayesNet Network.cc Node.cc Metrics.cc BaseClassifier.cc KDB.cc) add_library(BayesNet Network.cc Node.cc Metrics.cc BaseClassifier.cc KDB.cc TAN.cc)
target_link_libraries(BayesNet "${TORCH_LIBRARIES}") target_link_libraries(BayesNet "${TORCH_LIBRARIES}")

View File

@ -1,17 +1,9 @@
#include "KDB.h" #include "KDB.h"
#include "Metrics.hpp"
namespace bayesnet { namespace bayesnet {
using namespace std; using namespace std;
using namespace torch; using namespace torch;
vector<int> argsort(vector<float>& nums)
{
int n = nums.size();
vector<int> indices(n);
iota(indices.begin(), indices.end(), 0);
sort(indices.begin(), indices.end(), [&nums](int i, int j) {return nums[i] > nums[j];});
return indices;
}
KDB::KDB(int k, float theta) : BaseClassifier(Network()), k(k), theta(theta) {} KDB::KDB(int k, float theta) : BaseClassifier(Network()), k(k), theta(theta) {}
void KDB::train() void KDB::train()
{ {
@ -36,31 +28,23 @@ namespace bayesnet {
*/ */
// 1. For each feature Xi, compute mutual information, I(X;C), // 1. For each feature Xi, compute mutual information, I(X;C),
// where C is the class. // where C is the class.
cout << "Computing mutual information between features and class" << endl;
auto n_classes = states[className].size();
auto metrics = Metrics(dataset, features, className, n_classes);
vector <float> mi; vector <float> mi;
for (auto i = 0; i < features.size(); i++) { for (auto i = 0; i < features.size(); i++) {
Tensor firstFeature = X.index({ "...", i }); Tensor firstFeature = X.index({ "...", i });
mi.push_back(metrics.mutualInformation(firstFeature, y)); mi.push_back(metrics.mutualInformation(firstFeature, y));
cout << "Mutual information between " << features[i] << " and " << className << " is " << mi[i] << endl;
} }
// 2. Compute class conditional mutual information I(Xi;Xj|C), for each // 2. Compute class conditional mutual information I(Xi;Xj|C), for each
auto conditionalEdgeWeights = metrics.conditionalEdge(); auto conditionalEdgeWeights = metrics.conditionalEdge();
cout << "Conditional edge weights" << endl;
cout << conditionalEdgeWeights << endl;
// 3. Let the used variable list, S, be empty. // 3. Let the used variable list, S, be empty.
vector<int> S; vector<int> S;
// 4. Let the DAG network being constructed, BN, begin with a single // 4. Let the DAG network being constructed, BN, begin with a single
// class node, C. // class node, C.
model.addNode(className, states[className].size()); model.addNode(className, states[className].size());
cout << "Adding node " << className << " to the network" << endl;
// 5. Repeat until S includes all domain features // 5. Repeat until S includes all domain features
// 5.1. Select feature Xmax which is not in S and has the largest value // 5.1. Select feature Xmax which is not in S and has the largest value
// I(Xmax;C). // I(Xmax;C).
auto order = argsort(mi); auto order = argsort(mi);
for (auto idx : order) { for (auto idx : order) {
cout << idx << " " << mi[idx] << endl;
// 5.2. Add a node to BN representing Xmax. // 5.2. Add a node to BN representing Xmax.
model.addNode(features[idx], states[features[idx]].size()); model.addNode(features[idx], states[features[idx]].size());
// 5.3. Add an arc from C to Xmax in BN. // 5.3. Add an arc from C to Xmax in BN.
@ -76,8 +60,6 @@ namespace bayesnet {
{ {
auto n_edges = min(k, static_cast<int>(S.size())); auto n_edges = min(k, static_cast<int>(S.size()));
auto cond_w = clone(weights); auto cond_w = clone(weights);
cout << "Conditional edge weights cloned for idx " << idx << endl;
cout << cond_w << endl;
bool exit_cond = k == 0; bool exit_cond = k == 0;
int num = 0; int num = 0;
while (!exit_cond) { while (!exit_cond) {
@ -93,22 +75,9 @@ namespace bayesnet {
} }
} }
cond_w.index_put_({ idx, max_minfo }, -1); cond_w.index_put_({ idx, max_minfo }, -1);
cout << "Conditional edge weights cloned for idx " << idx << " After -1" << endl;
cout << cond_w << endl;
cout << "cond_w.index({ idx, '...'})" << endl;
cout << cond_w.index({ idx, "..." }) << endl;
auto candidates_mask = cond_w.index({ idx, "..." }).gt(theta); auto candidates_mask = cond_w.index({ idx, "..." }).gt(theta);
auto candidates = candidates_mask.nonzero(); auto candidates = candidates_mask.nonzero();
cout << "Candidates mask" << endl;
cout << candidates_mask << endl;
cout << "Candidates: " << endl;
cout << candidates << endl;
cout << "Candidates size: " << candidates.size(0) << endl;
exit_cond = num == n_edges || candidates.size(0) == 0; exit_cond = num == n_edges || candidates.size(0) == 0;
} }
} }
vector<string> KDB::show()
{
return model.show();
}
} }

View File

@ -13,7 +13,6 @@ namespace bayesnet {
void train() override; void train() override;
public: public:
KDB(int k, float theta = 0.03); KDB(int k, float theta = 0.03);
vector<string> show();
}; };
} }
#endif #endif

View File

@ -116,4 +116,12 @@ namespace bayesnet {
{ {
return entropy(firstFeature) - conditionalEntropy(firstFeature, secondFeature); return entropy(firstFeature) - conditionalEntropy(firstFeature, secondFeature);
} }
vector<pair<int, int>> Metrics::maximumSpanningTree(Tensor& weights)
{
    // Compute the maximum spanning tree of the complete graph whose nodes are
    // the row/column indices of the square matrix `weights` and whose edge
    // (i, j) has weight weights[i][j] (assumed symmetric — TODO confirm).
    // Uses Prim's algorithm, greedily growing the tree from node 0 and always
    // attaching the outside node with the heaviest edge into the tree.
    // Returns the tree edges as (parent-in-tree, child) pairs.
    auto result = vector<pair<int, int>>();
    int num_nodes = weights.size(0);
    if (num_nodes <= 1) {
        return result; // nothing to connect
    }
    vector<bool> inTree(num_nodes, false);
    vector<float> bestWeight(num_nodes, 0.0f); // best edge weight into the tree, per outside node
    vector<int> bestParent(num_nodes, 0);      // tree endpoint of that best edge
    inTree[0] = true;
    for (int j = 1; j < num_nodes; ++j) {
        bestWeight[j] = weights[0][j].item<float>();
    }
    for (int added = 1; added < num_nodes; ++added) {
        // Pick the outside node with the heaviest connecting edge.
        int next = -1;
        for (int j = 0; j < num_nodes; ++j) {
            if (!inTree[j] && (next == -1 || bestWeight[j] > bestWeight[next])) {
                next = j;
            }
        }
        inTree[next] = true;
        result.push_back({ bestParent[next], next });
        // Relax the remaining outside nodes against the newly added one.
        for (int j = 0; j < num_nodes; ++j) {
            if (!inTree[j]) {
                float w = weights[next][j].item<float>();
                if (w > bestWeight[j]) {
                    bestWeight[j] = w;
                    bestParent[j] = next;
                }
            }
        }
    }
    return result;
}
} }

View File

@ -3,23 +3,26 @@
#include <torch/torch.h> #include <torch/torch.h>
#include <vector> #include <vector>
#include <string> #include <string>
using namespace std;
namespace bayesnet { namespace bayesnet {
using namespace std;
using namespace torch;
class Metrics { class Metrics {
private: private:
torch::Tensor samples; Tensor samples;
vector<string> features; vector<string> features;
string className; string className;
int classNumStates; int classNumStates;
vector<pair<string, string>> doCombinations(const vector<string>&);
double entropy(torch::Tensor&);
double conditionalEntropy(torch::Tensor&, torch::Tensor&);
public: public:
double mutualInformation(torch::Tensor&, torch::Tensor&); Metrics() = default;
Metrics(torch::Tensor&, vector<string>&, string&, int); Metrics(Tensor&, vector<string>&, string&, int);
Metrics(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&, const int); Metrics(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&, const int);
double entropy(Tensor&);
double conditionalEntropy(Tensor&, Tensor&);
double mutualInformation(Tensor&, Tensor&);
vector<float> conditionalEdgeWeights(); vector<float> conditionalEdgeWeights();
torch::Tensor conditionalEdge(); Tensor conditionalEdge();
vector<pair<string, string>> doCombinations(const vector<string>&);
vector<pair<int, int>> maximumSpanningTree(Tensor& weights);
}; };
} }
#endif #endif

25
src/TAN.cc Normal file
View File

@ -0,0 +1,25 @@
#include "TAN.h"
namespace bayesnet {
using namespace std;
using namespace torch;
// Construct a TAN classifier over an empty network; the tree structure is learned in train().
TAN::TAN() : BaseClassifier(Network()) {}
void TAN::train()
{
    // Learn the Tree Augmented Naive Bayes structure.
    // 0. Add all nodes (every feature plus the class) to the model
    addNodes();
    // 1. Compute the pairwise conditional mutual information between
    //    features given the class — the edge weights of the complete graph
    //    (conditionalEdge(), not plain feature/class MI)
    auto weights = metrics.conditionalEdge();
    // 2. Compute the maximum spanning tree over those weights
    auto mst = metrics.maximumSpanningTree(weights);
    // 3. Orient the tree edges and add them to the model
    //    (range-for avoids the signed/unsigned index comparison)
    for (const auto& [from, to] : mst) {
        model.addEdge(features[from], features[to]);
    }
    // NOTE(review): TAN also requires an arc from the class to every feature;
    // those arcs are not added here yet — confirm against the intended design.
}
}

15
src/TAN.h Normal file
View File

@ -0,0 +1,15 @@
#ifndef TAN_H
#define TAN_H
#include "BaseClassifier.h"
namespace bayesnet {
using namespace std;
using namespace torch;
    // Tree Augmented Naive Bayes classifier: augments the naive Bayes model
    // with a tree-shaped dependency structure among the features,
    // learned in train() via a maximum spanning tree (see TAN.cc).
    class TAN : public BaseClassifier {
    private:
    protected:
        // Builds the network structure; invoked by BaseClassifier::build().
        void train() override;
    public:
        TAN();
    };
}
#endif