Add SPODENew to models

This commit is contained in:
Ricardo Montañana Gómez 2023-08-05 23:11:36 +02:00
parent 506ef34c6f
commit 1d0fd629c9
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
7 changed files with 63 additions and 5 deletions

View File

@ -1,4 +1,5 @@
include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp)
include_directories(${BayesNet_SOURCE_DIR}/lib/Files)
add_library(BayesNet bayesnetUtils.cc Network.cc Node.cc BayesMetrics.cc Classifier.cc KDB.cc TAN.cc SPODE.cc Ensemble.cc AODE.cc TANNew.cc KDBNew.cc Mst.cc Proposal.cc)
add_library(BayesNet bayesnetUtils.cc Network.cc Node.cc BayesMetrics.cc Classifier.cc
KDB.cc TAN.cc SPODE.cc Ensemble.cc AODE.cc TANNew.cc KDBNew.cc SPODENew.cc Mst.cc Proposal.cc)
target_link_libraries(BayesNet mdlp ArffFiles "${TORCH_LIBRARIES}")

View File

@ -8,7 +8,7 @@ namespace bayesnet {
class KDBNew : public KDB, public Proposal {
private:
public:
KDBNew(int k);
explicit KDBNew(int k);
virtual ~KDBNew() = default;
KDBNew& fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states) override;
vector<string> graph(const string& name = "KDB") override;

35
src/BayesNet/SPODENew.cc Normal file
View File

@ -0,0 +1,35 @@
#include "SPODENew.h"
namespace bayesnet {
using namespace std;
SPODENew::SPODENew(int root) : SPODE(root), Proposal(SPODE::Xv, SPODE::yv, features, className) {}
SPODENew& SPODENew::fit(torch::Tensor& X_, torch::Tensor& y_, vector<string>& features_, string className_, map<string, vector<int>>& states_)
{
// This first part should go in a Classifier method called fit_local_discretization o fit_float...
features = features_;
className = className_;
Xf = X_;
y = y_;
// Fills vectors Xv & yv with the data from tensors X_ (discretized) & y
fit_local_discretization(states, y);
generateTensorXFromVector();
// We have discretized the input data
// 1st we need to fit the model to build the normal SPODE structure, SPODE::fit initializes the base Bayesian network
SPODE::fit(SPODE::Xv, SPODE::yv, features, className, states);
localDiscretizationProposal(states, model);
generateTensorXFromVector();
Tensor ytmp = torch::transpose(y.view({ y.size(0), 1 }), 0, 1);
samples = torch::cat({ X, ytmp }, 0);
model.fit(SPODE::Xv, SPODE::yv, features, className);
return *this;
}
Tensor SPODENew::predict(Tensor& X)
{
auto Xt = prepareX(X);
return SPODE::predict(Xt);
}
vector<string> SPODENew::graph(const string& name)
{
return SPODE::graph(name);
}
}

19
src/BayesNet/SPODENew.h Normal file
View File

@ -0,0 +1,19 @@
#ifndef SPODENEW_H
#define SPODENEW_H
#include "SPODE.h"
#include "Proposal.h"
namespace bayesnet {
using namespace std;
class SPODENew : public SPODE, public Proposal {
private:
public:
explicit SPODENew(int root);
virtual ~SPODENew() = default;
SPODENew& fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states) override;
vector<string> graph(const string& name = "SPODE") override;
Tensor predict(Tensor& X) override;
static inline string version() { return "0.0.1"; };
};
}
#endif // !SPODENew_H

View File

@ -8,6 +8,7 @@
#include "SPODE.h"
#include "TANNew.h"
#include "KDBNew.h"
#include "SPODENew.h"
namespace platform {
class Models {
private:

View File

@ -48,9 +48,9 @@ namespace platform {
cout << setw(6) << right << r["samples"].get<int>() << " ";
cout << setw(5) << right << r["features"].get<int>() << " ";
cout << setw(3) << right << r["classes"].get<int>() << " ";
cout << setw(7) << right << r["nodes"].get<float>() << " ";
cout << setw(7) << right << r["leaves"].get<float>() << " ";
cout << setw(7) << right << r["depth"].get<float>() << " ";
cout << setw(7) << setprecision(2) << fixed << r["nodes"].get<float>() << " ";
cout << setw(7) << setprecision(2) << fixed << r["leaves"].get<float>() << " ";
cout << setw(7) << setprecision(2) << fixed << r["depth"].get<float>() << " ";
cout << setw(8) << right << setprecision(6) << fixed << r["score_test"].get<double>() << "±" << setw(6) << setprecision(4) << fixed << r["score_test_std"].get<double>() << " ";
cout << setw(10) << right << setprecision(6) << fixed << r["test_time"].get<double>() << "±" << setw(6) << setprecision(4) << fixed << r["test_time_std"].get<double>() << " ";
cout << " " << r["hyperparameters"].get<string>();

View File

@ -6,6 +6,8 @@ static platform::Registrar registrarTN("TANNew",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::TANNew();});
static platform::Registrar registrarS("SPODE",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::SPODE(2);});
static platform::Registrar registrarSN("SPODENew",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::SPODENew(2);});
static platform::Registrar registrarK("KDB",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::KDB(2);});
static platform::Registrar registrarKN("KDBNew",