Refactor BaseClassifier and begin TAN impl.

This commit is contained in:
Ricardo Montañana Gómez 2023-07-14 00:10:55 +02:00
parent e52fdc718f
commit 3f09d474f9
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
9 changed files with 87 additions and 49 deletions

View File

@ -4,7 +4,7 @@ namespace bayesnet {
using namespace std;
using namespace torch;
BaseClassifier::BaseClassifier(Network model) : model(model), m(0), n(0) {}
BaseClassifier::BaseClassifier(Network model) : model(model), m(0), n(0), metrics(Metrics()) {}
BaseClassifier& BaseClassifier::build(vector<string>& features, string className, map<string, vector<int>>& states)
{
@ -13,6 +13,8 @@ namespace bayesnet {
this->className = className;
this->states = states;
checkFitParameters();
auto n_classes = states[className].size();
metrics = Metrics(dataset, features, className, n_classes);
train();
return *this;
}
@ -51,6 +53,14 @@ namespace bayesnet {
}
}
}
vector<int> BaseClassifier::argsort(vector<float>& nums)
{
int n = nums.size();
vector<int> indices(n);
iota(indices.begin(), indices.end(), 0);
sort(indices.begin(), indices.end(), [&nums](int i, int j) {return nums[i] > nums[j];});
return indices;
}
vector<vector<int>> tensorToVector(const torch::Tensor& tensor)
{
// convert mxn tensor to nxm vector
@ -86,8 +96,16 @@ namespace bayesnet {
Tensor y_pred = predict(X);
return (y_pred == y).sum().item<float>() / y.size(0);
}
void BaseClassifier::show()
vector<string> BaseClassifier::show()
{
model.show();
return model.show();
}
void BaseClassifier::addNodes()
{
// Add all nodes to the network
for (const auto& feature : features) {
model.addNode(feature, states[feature].size());
}
model.addNode(className, states[className].size());
}
}
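A side note on the argsort that just moved into the base class: the comparison lambda sorts indices by descending value, so the first index always points at the largest entry. A minimal standalone check (the function body is copied verbatim from above; the sample values are made up):

#include <algorithm>
#include <iostream>
#include <numeric>
#include <vector>
using namespace std;
vector<int> argsort(vector<float>& nums)
{
    int n = nums.size();
    vector<int> indices(n);
    iota(indices.begin(), indices.end(), 0);
    sort(indices.begin(), indices.end(), [&nums](int i, int j) {return nums[i] > nums[j];});
    return indices;
}
int main()
{
    // e.g. mutual information scores, one per feature (hypothetical values)
    vector<float> mi = { 0.10f, 0.52f, 0.31f };
    for (int idx : argsort(mi))
        cout << idx << " "; // prints: 1 2 0
    cout << endl;
}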

View File

@ -1,6 +1,7 @@
#ifndef CLASSIFIERS_H
#include <torch/torch.h>
#include "Network.h"
#include "Metrics.hpp"
using namespace std;
using namespace torch;
@ -14,6 +15,7 @@ namespace bayesnet {
Tensor X;
Tensor y;
Tensor dataset;
Metrics metrics;
vector<string> features;
string className;
map<string, vector<int>> states;
@ -21,14 +23,13 @@ namespace bayesnet {
virtual void train() = 0;
public:
BaseClassifier(Network model);
Tensor& getX();
vector<string>& getFeatures();
string& getClassName();
BaseClassifier& fit(Tensor& X, Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states);
BaseClassifier& fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states);
void addNodes();
Tensor predict(Tensor& X);
float score(Tensor& X, Tensor& y);
void show();
vector<string> show();
vector<int> argsort(vector<float>& nums);
};
}
#endif

View File

@ -1,2 +1,2 @@
add_library(BayesNet Network.cc Node.cc Metrics.cc BaseClassifier.cc KDB.cc)
add_library(BayesNet Network.cc Node.cc Metrics.cc BaseClassifier.cc KDB.cc TAN.cc)
target_link_libraries(BayesNet "${TORCH_LIBRARIES}")

View File

@ -1,17 +1,9 @@
#include "KDB.h"
#include "Metrics.hpp"
namespace bayesnet {
using namespace std;
using namespace torch;
vector<int> argsort(vector<float>& nums)
{
int n = nums.size();
vector<int> indices(n);
iota(indices.begin(), indices.end(), 0);
sort(indices.begin(), indices.end(), [&nums](int i, int j) {return nums[i] > nums[j];});
return indices;
}
KDB::KDB(int k, float theta) : BaseClassifier(Network()), k(k), theta(theta) {}
void KDB::train()
{
@ -36,31 +28,23 @@ namespace bayesnet {
*/
// 1. For each feature Xi, compute mutual information, I(Xi;C),
// where C is the class.
cout << "Computing mutual information between features and class" << endl;
auto n_classes = states[className].size();
auto metrics = Metrics(dataset, features, className, n_classes);
vector <float> mi;
for (auto i = 0; i < features.size(); i++) {
Tensor firstFeature = X.index({ "...", i });
mi.push_back(metrics.mutualInformation(firstFeature, y));
cout << "Mutual information between " << features[i] << " and " << className << " is " << mi[i] << endl;
}
// 2. Compute class conditional mutual information I(Xi;Xj|C), for each pair of features Xi and Xj, where i != j
auto conditionalEdgeWeights = metrics.conditionalEdge();
cout << "Conditional edge weights" << endl;
cout << conditionalEdgeWeights << endl;
// 3. Let the used variable list, S, be empty.
vector<int> S;
// 4. Let the DAG network being constructed, BN, begin with a single
// class node, C.
model.addNode(className, states[className].size());
cout << "Adding node " << className << " to the network" << endl;
// 5. Repeat until S includes all domain features
// 5.1. Select feature Xmax which is not in S and has the largest value
// I(Xmax;C).
auto order = argsort(mi);
for (auto idx : order) {
cout << idx << " " << mi[idx] << endl;
// 5.2. Add a node to BN representing Xmax.
model.addNode(features[idx], states[features[idx]].size());
// 5.3. Add an arc from C to Xmax in BN.
@ -76,8 +60,6 @@ namespace bayesnet {
{
auto n_edges = min(k, static_cast<int>(S.size()));
auto cond_w = clone(weights);
cout << "Conditional edge weights cloned for idx " << idx << endl;
cout << cond_w << endl;
bool exit_cond = k == 0;
int num = 0;
while (!exit_cond) {
@ -93,22 +75,9 @@ namespace bayesnet {
}
}
cond_w.index_put_({ idx, max_minfo }, -1);
cout << "Conditional edge weights cloned for idx " << idx << " After -1" << endl;
cout << cond_w << endl;
cout << "cond_w.index({ idx, '...'})" << endl;
cout << cond_w.index({ idx, "..." }) << endl;
auto candidates_mask = cond_w.index({ idx, "..." }).gt(theta);
auto candidates = candidates_mask.nonzero();
cout << "Candidates mask" << endl;
cout << candidates_mask << endl;
cout << "Candidates: " << endl;
cout << candidates << endl;
cout << "Candidates size: " << candidates.size(0) << endl;
exit_cond = num == n_edges || candidates.size(0) == 0;
}
}
vector<string> KDB::show()
{
return model.show();
}
}
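The candidate filtering that survives the cleanup above hinges on gt(theta) followed by nonzero(). A small libtorch sketch of that step in isolation (the weight matrix, idx and theta below are made up for illustration, not taken from the classifier):

#include <torch/torch.h>
#include <iostream>
int main()
{
    // Hypothetical 3x3 conditional edge weight matrix
    auto cond_w = torch::tensor({ { 0.00, 0.12, 0.01 },
                                  { 0.12, 0.00, 0.07 },
                                  { 0.01, 0.07, 0.00 } });
    float theta = 0.03f;
    int idx = 1;
    // keep only the entries of row idx whose weight exceeds theta,
    // mirroring the candidate selection in the loop above
    auto candidates_mask = cond_w.index({ idx, "..." }).gt(theta);
    auto candidates = candidates_mask.nonzero();
    std::cout << candidates << std::endl; // indices 0 and 2
}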

View File

@ -13,7 +13,6 @@ namespace bayesnet {
void train() override;
public:
KDB(int k, float theta = 0.03);
vector<string> show();
};
}
#endif

View File

@ -116,4 +116,12 @@ namespace bayesnet {
{
return entropy(firstFeature) - conditionalEntropy(firstFeature, secondFeature);
}
vector<pair<int, int>> Metrics::maximumSpanningTree(Tensor& weights)
{
auto result = vector<pair<int, int>>();
// Compute the maximum spanning tree considering the weights as distances
// and the indices of the weights as nodes of this square matrix
return result;
}
}
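maximumSpanningTree is deliberately left as a stub in this commit. For orientation only, one possible shape for it is Prim's algorithm run over the dense, symmetric weight matrix; the sketch below assumes weights(i, j) holds the edge weight between features i and j, and it is not the implementation that will land later:

#include <torch/torch.h>
#include <limits>
#include <utility>
#include <vector>
std::vector<std::pair<int, int>> maximumSpanningTreeSketch(const torch::Tensor& weights)
{
    int n = static_cast<int>(weights.size(0));
    std::vector<bool> inTree(n, false);
    std::vector<float> best(n, -std::numeric_limits<float>::infinity());
    std::vector<int> parent(n, -1);
    std::vector<std::pair<int, int>> result;
    if (n == 0)
        return result;
    inTree[0] = true; // grow the tree from node 0
    for (int i = 1; i < n; ++i) {
        best[i] = weights.index({ 0, i }).item<float>();
        parent[i] = 0;
    }
    for (int added = 1; added < n; ++added) {
        // pick the heaviest edge connecting the tree to a node outside it
        int next = -1;
        for (int i = 0; i < n; ++i)
            if (!inTree[i] && (next == -1 || best[i] > best[next]))
                next = i;
        inTree[next] = true;
        result.push_back({ parent[next], next });
        // refresh the best known connection for every node still outside the tree
        for (int i = 0; i < n; ++i)
            if (!inTree[i]) {
                float w = weights.index({ next, i }).item<float>();
                if (w > best[i]) {
                    best[i] = w;
                    parent[i] = next;
                }
            }
    }
    return result;
}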

View File

@ -3,23 +3,26 @@
#include <torch/torch.h>
#include <vector>
#include <string>
using namespace std;
namespace bayesnet {
using namespace std;
using namespace torch;
class Metrics {
private:
torch::Tensor samples;
Tensor samples;
vector<string> features;
string className;
int classNumStates;
vector<pair<string, string>> doCombinations(const vector<string>&);
double entropy(torch::Tensor&);
double conditionalEntropy(torch::Tensor&, torch::Tensor&);
public:
double mutualInformation(torch::Tensor&, torch::Tensor&);
Metrics(torch::Tensor&, vector<string>&, string&, int);
Metrics() = default;
Metrics(Tensor&, vector<string>&, string&, int);
Metrics(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&, const int);
double entropy(Tensor&);
double conditionalEntropy(Tensor&, Tensor&);
double mutualInformation(Tensor&, Tensor&);
vector<float> conditionalEdgeWeights();
torch::Tensor conditionalEdge();
Tensor conditionalEdge();
vector<pair<string, string>> doCombinations(const vector<string>&);
vector<pair<int, int>> maximumSpanningTree(Tensor& weights);
};
}
#endif

src/TAN.cc Normal file (25 additions)
View File

@ -0,0 +1,25 @@
#include "TAN.h"
namespace bayesnet {
using namespace std;
using namespace torch;
TAN::TAN() : BaseClassifier(Network()) {}
void TAN::train()
{
// 0. Add all nodes to the model
addNodes();
// 1. Compute the conditional mutual information between every pair of features, given the class
auto weights = metrics.conditionalEdge();
// 2. Compute the maximum spanning tree
auto mst = metrics.maximumSpanningTree(weights);
// 3. Add edges from the maximum spanning tree to the model
for (const auto& [from, to] : mst) {
model.addEdge(features[from], features[to]);
}
}
}
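Finally, a minimal usage sketch for the new classifier, assuming the vector-based fit overload declared in BaseClassifier.h. The toy data, feature names and the row-per-feature layout of X are assumptions made for illustration, and since maximumSpanningTree is still a stub the fitted model will hold the nodes but no feature-to-feature edges yet:

#include "TAN.h"
#include <iostream>
#include <map>
#include <string>
#include <vector>
using namespace std;
int main()
{
    // Hypothetical discretized data: two binary features, binary class
    vector<vector<int>> X = { { 0, 1, 1, 0 }, { 1, 1, 0, 0 } };
    vector<int> y = { 0, 1, 1, 0 };
    vector<string> features = { "f1", "f2" };
    string className = "class";
    map<string, vector<int>> states = {
        { "f1", { 0, 1 } }, { "f2", { 0, 1 } }, { "class", { 0, 1 } }
    };
    bayesnet::TAN clf;
    clf.fit(X, y, features, className, states);
    for (const auto& line : clf.show()) // show() now returns the model description
        cout << line << endl;
}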

src/TAN.h Normal file (15 additions)
View File

@ -0,0 +1,15 @@
#ifndef TAN_H
#define TAN_H
#include "BaseClassifier.h"
namespace bayesnet {
using namespace std;
using namespace torch;
class TAN : public BaseClassifier {
private:
protected:
void train() override;
public:
TAN();
};
}
#endif