diff --git a/src/BayesNet/KDBNew.cc b/src/BayesNet/KDBNew.cc index 8adf3b4..f2f6e46 100644 --- a/src/BayesNet/KDBNew.cc +++ b/src/BayesNet/KDBNew.cc @@ -10,15 +10,12 @@ namespace bayesnet { className = className_; Xf = X_; y = y_; - model.initialize(); // Fills vectors Xv & yv with the data from tensors X_ (discretized) & y fit_local_discretization(states, y); generateTensorXFromVector(); // We have discretized the input data - // 1st we need to fit the model to build the normal TAN structure, TAN::fit initializes the base Bayesian network - cout << "KDBNew: Fitting model" << endl; + // 1st we need to fit the model to build the normal KDB structure, KDB::fit initializes the base Bayesian network KDB::fit(KDB::Xv, KDB::yv, features, className, states); - cout << "KDBNew: Model fitted" << endl; localDiscretizationProposal(states, model); generateTensorXFromVector(); Tensor ytmp = torch::transpose(y.view({ y.size(0), 1 }), 0, 1); @@ -26,20 +23,10 @@ namespace bayesnet { model.fit(KDB::Xv, KDB::yv, features, className); return *this; } - void KDBNew::train() - { - KDB::train(); - } Tensor KDBNew::predict(Tensor& X) { - auto Xtd = torch::zeros_like(X, torch::kInt32); - for (int i = 0; i < X.size(0); ++i) { - auto Xt = vector(X[i].data_ptr(), X[i].data_ptr() + X.size(1)); - auto Xd = discretizers[features[i]]->transform(Xt); - Xtd.index_put_({ i }, torch::tensor(Xd, torch::kInt32)); - } - cout << "KDBNew Xtd: " << Xtd.sizes() << endl; - return KDB::predict(Xtd); + auto Xt = prepareX(X); + return KDB::predict(Xt); } vector KDBNew::graph(const string& name) { diff --git a/src/BayesNet/KDBNew.h b/src/BayesNet/KDBNew.h index 2bcd213..63ae9e9 100644 --- a/src/BayesNet/KDBNew.h +++ b/src/BayesNet/KDBNew.h @@ -13,7 +13,6 @@ namespace bayesnet { KDBNew& fit(torch::Tensor& X, torch::Tensor& y, vector& features, string className, map>& states) override; vector graph(const string& name = "KDB") override; Tensor predict(Tensor& X) override; - void train() override; static inline string version() { return "0.0.1"; }; }; } diff --git a/src/BayesNet/Proposal.cc b/src/BayesNet/Proposal.cc index 2156b9a..c1d3626 100644 --- a/src/BayesNet/Proposal.cc +++ b/src/BayesNet/Proposal.cc @@ -47,25 +47,25 @@ namespace bayesnet { // // // - auto tmp = discretizers[feature]->transform(xvf); - Xv[index] = tmp; - auto xStates = vector(discretizers[pFeatures[index]]->getCutPoints().size() + 1); - iota(xStates.begin(), xStates.end(), 0); - //Update new states of the feature/node - states[feature] = xStates; + // auto tmp = discretizers[feature]->transform(xvf); + // Xv[index] = tmp; + // auto xStates = vector(discretizers[pFeatures[index]]->getCutPoints().size() + 1); + // iota(xStates.begin(), xStates.end(), 0); + // //Update new states of the feature/node + // states[feature] = xStates; + } + if (upgrade) { + // Discretize again X (only the affected indices) with the new fitted discretizers + for (auto index : indicesToReDiscretize) { + auto Xt_ptr = Xf.index({ index }).data_ptr(); + auto Xt = vector(Xt_ptr, Xt_ptr + Xf.size(1)); + Xv[index] = discretizers[pFeatures[index]]->transform(Xt); + auto xStates = vector(discretizers[pFeatures[index]]->getCutPoints().size() + 1); + iota(xStates.begin(), xStates.end(), 0); + //Update new states of the feature/node + states[pFeatures[index]] = xStates; + } } - // if (upgrade) { - // // Discretize again X (only the affected indices) with the new fitted discretizers - // for (auto index : indicesToReDiscretize) { - // auto Xt_ptr = Xf.index({ index }).data_ptr(); - // auto Xt = vector(Xt_ptr, Xt_ptr + Xf.size(1)); - // Xv[index] = discretizers[pFeatures[index]]->transform(Xt); - // auto xStates = vector(discretizers[pFeatures[index]]->getCutPoints().size() + 1); - // iota(xStates.begin(), xStates.end(), 0); - // //Update new states of the feature/node - // states[pFeatures[index]] = xStates; - // } - // } } void Proposal::fit_local_discretization(map>& states, torch::Tensor& y) { @@ -89,4 +89,14 @@ namespace bayesnet { iota(yStates.begin(), yStates.end(), 0); states[pClassName] = yStates; } + torch::Tensor Proposal::prepareX(torch::Tensor& X) + { + auto Xtd = torch::zeros_like(X, torch::kInt32); + for (int i = 0; i < X.size(0); ++i) { + auto Xt = vector(X[i].data_ptr(), X[i].data_ptr() + X.size(1)); + auto Xd = discretizers[pFeatures[i]]->transform(Xt); + Xtd.index_put_({ i }, torch::tensor(Xd, torch::kInt32)); + } + return Xtd; + } } \ No newline at end of file diff --git a/src/BayesNet/Proposal.h b/src/BayesNet/Proposal.h index 981c22c..a5650b4 100644 --- a/src/BayesNet/Proposal.h +++ b/src/BayesNet/Proposal.h @@ -5,6 +5,7 @@ #include #include "Network.h" #include "CPPFImdlp.h" +#include "Classifier.h" namespace bayesnet { class Proposal { @@ -12,6 +13,7 @@ namespace bayesnet { Proposal(vector>& Xv_, vector& yv_, vector& features_, string& className_); virtual ~Proposal(); protected: + torch::Tensor prepareX(torch::Tensor& X); void localDiscretizationProposal(map>& states, Network& model); void fit_local_discretization(map>& states, torch::Tensor& y); torch::Tensor Xf; // X continuous nxm tensor diff --git a/src/BayesNet/TANNew.cc b/src/BayesNet/TANNew.cc index e0e1f0b..15a1eaf 100644 --- a/src/BayesNet/TANNew.cc +++ b/src/BayesNet/TANNew.cc @@ -15,9 +15,7 @@ namespace bayesnet { generateTensorXFromVector(); // We have discretized the input data // 1st we need to fit the model to build the normal TAN structure, TAN::fit initializes the base Bayesian network - cout << "TANNew: Fitting model" << endl; TAN::fit(TAN::Xv, TAN::yv, features, className, states); - cout << "TANNew: Model fitted" << endl; localDiscretizationProposal(states, model); generateTensorXFromVector(); Tensor ytmp = torch::transpose(y.view({ y.size(0), 1 }), 0, 1); @@ -27,14 +25,8 @@ namespace bayesnet { } Tensor TANNew::predict(Tensor& X) { - auto Xtd = torch::zeros_like(X, torch::kInt32); - for (int i = 0; i < X.size(0); ++i) { - auto Xt = vector(X[i].data_ptr(), X[i].data_ptr() + X.size(1)); - auto Xd = discretizers[features[i]]->transform(Xt); - Xtd.index_put_({ i }, torch::tensor(Xd, torch::kInt32)); - } - cout << "TANNew Xtd: " << Xtd.sizes() << endl; - return TAN::predict(Xtd); + auto Xt = prepareX(X); + return TAN::predict(Xt); } vector TANNew::graph(const string& name) { diff --git a/src/Platform/Experiment.cc b/src/Platform/Experiment.cc index 261b1c5..a79216c 100644 --- a/src/Platform/Experiment.cc +++ b/src/Platform/Experiment.cc @@ -146,11 +146,6 @@ namespace platform { auto y_test = y.index({ test_t }); cout << nfold + 1 << ", " << flush; clf->fit(X_train, y_train, features, className, states); - cout << endl; - auto lines = clf->show(); - for (auto line : lines) { - cout << line << endl; - } nodes[item] = clf->getNumberOfNodes(); edges[item] = clf->getNumberOfEdges(); num_states[item] = clf->getNumberOfStates();