Refactor New classifiers to extract predict

This commit is contained in:
Ricardo Montañana Gómez 2023-08-05 18:39:48 +02:00
parent 1a09ccca4c
commit 7f45495837
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
6 changed files with 35 additions and 50 deletions

View File

@ -10,15 +10,12 @@ namespace bayesnet {
className = className_;
Xf = X_;
y = y_;
model.initialize();
// Fills vectors Xv & yv with the data from tensors X_ (discretized) & y
fit_local_discretization(states, y);
generateTensorXFromVector();
// We have discretized the input data
// 1st we need to fit the model to build the normal TAN structure, TAN::fit initializes the base Bayesian network
cout << "KDBNew: Fitting model" << endl;
// 1st we need to fit the model to build the normal KDB structure, KDB::fit initializes the base Bayesian network
KDB::fit(KDB::Xv, KDB::yv, features, className, states);
cout << "KDBNew: Model fitted" << endl;
localDiscretizationProposal(states, model);
generateTensorXFromVector();
Tensor ytmp = torch::transpose(y.view({ y.size(0), 1 }), 0, 1);
@ -26,20 +23,10 @@ namespace bayesnet {
model.fit(KDB::Xv, KDB::yv, features, className);
return *this;
}
// Trains the classifier by delegating entirely to the base KDB
// implementation; KDBNew adds no training logic of its own.
// (This override is removed by this commit together with its declaration.)
void KDBNew::train()
{
KDB::train();
}
// Predicts class labels for continuous input X by first discretizing it
// with the per-feature discretizers fitted during fit(), then delegating
// to the base KDB::predict on the integer tensor.
// NOTE(review): this span fuses the pre- and post-refactor bodies (the
// diff's +/- markers were lost in extraction) — the statements after the
// first return are unreachable as rendered. The post-refactor body is the
// two-line prepareX(X) + KDB::predict(Xt) version.
Tensor KDBNew::predict(Tensor& X)
{
auto Xtd = torch::zeros_like(X, torch::kInt32);
// Discretize one row per feature; assumes row i holds feature i's
// continuous values — TODO confirm layout against generateTensorXFromVector.
for (int i = 0; i < X.size(0); ++i) {
auto Xt = vector<float>(X[i].data_ptr<float>(), X[i].data_ptr<float>() + X.size(1));
auto Xd = discretizers[features[i]]->transform(Xt);
Xtd.index_put_({ i }, torch::tensor(Xd, torch::kInt32));
}
// Debug trace removed by this refactor.
cout << "KDBNew Xtd: " << Xtd.sizes() << endl;
return KDB::predict(Xtd);
// Refactored body: Proposal::prepareX now owns the discretization loop.
auto Xt = prepareX(X);
return KDB::predict(Xt);
}
vector<string> KDBNew::graph(const string& name)
{

View File

@ -13,7 +13,6 @@ namespace bayesnet {
KDBNew& fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states) override;
vector<string> graph(const string& name = "KDB") override;
Tensor predict(Tensor& X) override;
void train() override;
static inline string version() { return "0.0.1"; };
};
}

View File

@ -47,25 +47,25 @@ namespace bayesnet {
//
//
//
auto tmp = discretizers[feature]->transform(xvf);
Xv[index] = tmp;
auto xStates = vector<int>(discretizers[pFeatures[index]]->getCutPoints().size() + 1);
iota(xStates.begin(), xStates.end(), 0);
//Update new states of the feature/node
states[feature] = xStates;
// auto tmp = discretizers[feature]->transform(xvf);
// Xv[index] = tmp;
// auto xStates = vector<int>(discretizers[pFeatures[index]]->getCutPoints().size() + 1);
// iota(xStates.begin(), xStates.end(), 0);
// //Update new states of the feature/node
// states[feature] = xStates;
}
if (upgrade) {
// Discretize again X (only the affected indices) with the new fitted discretizers
for (auto index : indicesToReDiscretize) {
auto Xt_ptr = Xf.index({ index }).data_ptr<float>();
auto Xt = vector<float>(Xt_ptr, Xt_ptr + Xf.size(1));
Xv[index] = discretizers[pFeatures[index]]->transform(Xt);
auto xStates = vector<int>(discretizers[pFeatures[index]]->getCutPoints().size() + 1);
iota(xStates.begin(), xStates.end(), 0);
//Update new states of the feature/node
states[pFeatures[index]] = xStates;
}
}
// if (upgrade) {
// // Discretize again X (only the affected indices) with the new fitted discretizers
// for (auto index : indicesToReDiscretize) {
// auto Xt_ptr = Xf.index({ index }).data_ptr<float>();
// auto Xt = vector<float>(Xt_ptr, Xt_ptr + Xf.size(1));
// Xv[index] = discretizers[pFeatures[index]]->transform(Xt);
// auto xStates = vector<int>(discretizers[pFeatures[index]]->getCutPoints().size() + 1);
// iota(xStates.begin(), xStates.end(), 0);
// //Update new states of the feature/node
// states[pFeatures[index]] = xStates;
// }
// }
}
void Proposal::fit_local_discretization(map<string, vector<int>>& states, torch::Tensor& y)
{
@ -89,4 +89,14 @@ namespace bayesnet {
iota(yStates.begin(), yStates.end(), 0);
states[pClassName] = yStates;
}
// Discretizes the continuous tensor X into an integer tensor of the same
// shape, applying each feature's fitted discretizer to its row.
// Assumes row i of X holds the continuous values of feature pFeatures[i]
// — TODO confirm against the callers' tensor layout.
torch::Tensor Proposal::prepareX(torch::Tensor& X)
{
    auto discretized = torch::zeros_like(X, torch::kInt32);
    for (int row = 0; row < X.size(0); ++row) {
        // Copy the row into a contiguous float vector for the discretizer.
        float* rowBegin = X[row].data_ptr<float>();
        auto rowValues = vector<float>(rowBegin, rowBegin + X.size(1));
        auto rowBins = discretizers[pFeatures[row]]->transform(rowValues);
        discretized.index_put_({ row }, torch::tensor(rowBins, torch::kInt32));
    }
    return discretized;
}
}

View File

@ -5,6 +5,7 @@
#include <torch/torch.h>
#include "Network.h"
#include "CPPFImdlp.h"
#include "Classifier.h"
namespace bayesnet {
class Proposal {
@ -12,6 +13,7 @@ namespace bayesnet {
Proposal(vector<vector<int>>& Xv_, vector<int>& yv_, vector<string>& features_, string& className_);
virtual ~Proposal();
protected:
torch::Tensor prepareX(torch::Tensor& X);
void localDiscretizationProposal(map<string, vector<int>>& states, Network& model);
void fit_local_discretization(map<string, vector<int>>& states, torch::Tensor& y);
torch::Tensor Xf; // X continuous nxm tensor

View File

@ -15,9 +15,7 @@ namespace bayesnet {
generateTensorXFromVector();
// We have discretized the input data
// 1st we need to fit the model to build the normal TAN structure, TAN::fit initializes the base Bayesian network
cout << "TANNew: Fitting model" << endl;
TAN::fit(TAN::Xv, TAN::yv, features, className, states);
cout << "TANNew: Model fitted" << endl;
localDiscretizationProposal(states, model);
generateTensorXFromVector();
Tensor ytmp = torch::transpose(y.view({ y.size(0), 1 }), 0, 1);
@ -27,14 +25,8 @@ namespace bayesnet {
}
// Predicts class labels for continuous input X: discretizes it with the
// per-feature discretizers fitted during fit(), then delegates to the base
// TAN::predict on the resulting integer tensor.
// NOTE(review): this span fuses the pre- and post-refactor bodies (the
// diff's +/- markers were lost in extraction) — the statements after the
// first return are unreachable as rendered. The post-refactor body is the
// two-line prepareX(X) + TAN::predict(Xt) version, mirroring KDBNew.
Tensor TANNew::predict(Tensor& X)
{
auto Xtd = torch::zeros_like(X, torch::kInt32);
// Discretize one row per feature; assumes row i holds feature i's
// continuous values — TODO confirm layout against generateTensorXFromVector.
for (int i = 0; i < X.size(0); ++i) {
auto Xt = vector<float>(X[i].data_ptr<float>(), X[i].data_ptr<float>() + X.size(1));
auto Xd = discretizers[features[i]]->transform(Xt);
Xtd.index_put_({ i }, torch::tensor(Xd, torch::kInt32));
}
// Debug trace removed by this refactor.
cout << "TANNew Xtd: " << Xtd.sizes() << endl;
return TAN::predict(Xtd);
// Refactored body: Proposal::prepareX now owns the discretization loop.
auto Xt = prepareX(X);
return TAN::predict(Xt);
}
vector<string> TANNew::graph(const string& name)
{

View File

@ -146,11 +146,6 @@ namespace platform {
auto y_test = y.index({ test_t });
cout << nfold + 1 << ", " << flush;
clf->fit(X_train, y_train, features, className, states);
cout << endl;
auto lines = clf->show();
for (auto line : lines) {
cout << line << endl;
}
nodes[item] = clf->getNumberOfNodes();
edges[item] = clf->getNumberOfEdges();
num_states[item] = clf->getNumberOfStates();