Add SPODENew to models
This commit is contained in:
parent
506ef34c6f
commit
1d0fd629c9
@ -1,4 +1,5 @@
|
|||||||
include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp)
|
include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp)
|
||||||
include_directories(${BayesNet_SOURCE_DIR}/lib/Files)
|
include_directories(${BayesNet_SOURCE_DIR}/lib/Files)
|
||||||
add_library(BayesNet bayesnetUtils.cc Network.cc Node.cc BayesMetrics.cc Classifier.cc KDB.cc TAN.cc SPODE.cc Ensemble.cc AODE.cc TANNew.cc KDBNew.cc Mst.cc Proposal.cc)
|
add_library(BayesNet bayesnetUtils.cc Network.cc Node.cc BayesMetrics.cc Classifier.cc
|
||||||
|
KDB.cc TAN.cc SPODE.cc Ensemble.cc AODE.cc TANNew.cc KDBNew.cc SPODENew.cc Mst.cc Proposal.cc)
|
||||||
target_link_libraries(BayesNet mdlp ArffFiles "${TORCH_LIBRARIES}")
|
target_link_libraries(BayesNet mdlp ArffFiles "${TORCH_LIBRARIES}")
|
@ -8,7 +8,7 @@ namespace bayesnet {
|
|||||||
class KDBNew : public KDB, public Proposal {
|
class KDBNew : public KDB, public Proposal {
|
||||||
private:
|
private:
|
||||||
public:
|
public:
|
||||||
KDBNew(int k);
|
explicit KDBNew(int k);
|
||||||
virtual ~KDBNew() = default;
|
virtual ~KDBNew() = default;
|
||||||
KDBNew& fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states) override;
|
KDBNew& fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states) override;
|
||||||
vector<string> graph(const string& name = "KDB") override;
|
vector<string> graph(const string& name = "KDB") override;
|
||||||
|
35
src/BayesNet/SPODENew.cc
Normal file
35
src/BayesNet/SPODENew.cc
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
#include "SPODENew.h"
|
||||||
|
|
||||||
|
namespace bayesnet {
|
||||||
|
using namespace std;
|
||||||
|
SPODENew::SPODENew(int root) : SPODE(root), Proposal(SPODE::Xv, SPODE::yv, features, className) {}
|
||||||
|
SPODENew& SPODENew::fit(torch::Tensor& X_, torch::Tensor& y_, vector<string>& features_, string className_, map<string, vector<int>>& states_)
|
||||||
|
{
|
||||||
|
// This first part should go in a Classifier method called fit_local_discretization o fit_float...
|
||||||
|
features = features_;
|
||||||
|
className = className_;
|
||||||
|
Xf = X_;
|
||||||
|
y = y_;
|
||||||
|
// Fills vectors Xv & yv with the data from tensors X_ (discretized) & y
|
||||||
|
fit_local_discretization(states, y);
|
||||||
|
generateTensorXFromVector();
|
||||||
|
// We have discretized the input data
|
||||||
|
// 1st we need to fit the model to build the normal SPODE structure, SPODE::fit initializes the base Bayesian network
|
||||||
|
SPODE::fit(SPODE::Xv, SPODE::yv, features, className, states);
|
||||||
|
localDiscretizationProposal(states, model);
|
||||||
|
generateTensorXFromVector();
|
||||||
|
Tensor ytmp = torch::transpose(y.view({ y.size(0), 1 }), 0, 1);
|
||||||
|
samples = torch::cat({ X, ytmp }, 0);
|
||||||
|
model.fit(SPODE::Xv, SPODE::yv, features, className);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
Tensor SPODENew::predict(Tensor& X)
|
||||||
|
{
|
||||||
|
auto Xt = prepareX(X);
|
||||||
|
return SPODE::predict(Xt);
|
||||||
|
}
|
||||||
|
vector<string> SPODENew::graph(const string& name)
|
||||||
|
{
|
||||||
|
return SPODE::graph(name);
|
||||||
|
}
|
||||||
|
}
|
19
src/BayesNet/SPODENew.h
Normal file
19
src/BayesNet/SPODENew.h
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
#ifndef SPODENEW_H
|
||||||
|
#define SPODENEW_H
|
||||||
|
#include "SPODE.h"
|
||||||
|
#include "Proposal.h"
|
||||||
|
|
||||||
|
namespace bayesnet {
|
||||||
|
using namespace std;
|
||||||
|
class SPODENew : public SPODE, public Proposal {
|
||||||
|
private:
|
||||||
|
public:
|
||||||
|
explicit SPODENew(int root);
|
||||||
|
virtual ~SPODENew() = default;
|
||||||
|
SPODENew& fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states) override;
|
||||||
|
vector<string> graph(const string& name = "SPODE") override;
|
||||||
|
Tensor predict(Tensor& X) override;
|
||||||
|
static inline string version() { return "0.0.1"; };
|
||||||
|
};
|
||||||
|
}
|
||||||
|
#endif // !SPODENew_H
|
@ -8,6 +8,7 @@
|
|||||||
#include "SPODE.h"
|
#include "SPODE.h"
|
||||||
#include "TANNew.h"
|
#include "TANNew.h"
|
||||||
#include "KDBNew.h"
|
#include "KDBNew.h"
|
||||||
|
#include "SPODENew.h"
|
||||||
namespace platform {
|
namespace platform {
|
||||||
class Models {
|
class Models {
|
||||||
private:
|
private:
|
||||||
|
@ -48,9 +48,9 @@ namespace platform {
|
|||||||
cout << setw(6) << right << r["samples"].get<int>() << " ";
|
cout << setw(6) << right << r["samples"].get<int>() << " ";
|
||||||
cout << setw(5) << right << r["features"].get<int>() << " ";
|
cout << setw(5) << right << r["features"].get<int>() << " ";
|
||||||
cout << setw(3) << right << r["classes"].get<int>() << " ";
|
cout << setw(3) << right << r["classes"].get<int>() << " ";
|
||||||
cout << setw(7) << right << r["nodes"].get<float>() << " ";
|
cout << setw(7) << setprecision(2) << fixed << r["nodes"].get<float>() << " ";
|
||||||
cout << setw(7) << right << r["leaves"].get<float>() << " ";
|
cout << setw(7) << setprecision(2) << fixed << r["leaves"].get<float>() << " ";
|
||||||
cout << setw(7) << right << r["depth"].get<float>() << " ";
|
cout << setw(7) << setprecision(2) << fixed << r["depth"].get<float>() << " ";
|
||||||
cout << setw(8) << right << setprecision(6) << fixed << r["score_test"].get<double>() << "±" << setw(6) << setprecision(4) << fixed << r["score_test_std"].get<double>() << " ";
|
cout << setw(8) << right << setprecision(6) << fixed << r["score_test"].get<double>() << "±" << setw(6) << setprecision(4) << fixed << r["score_test_std"].get<double>() << " ";
|
||||||
cout << setw(10) << right << setprecision(6) << fixed << r["test_time"].get<double>() << "±" << setw(6) << setprecision(4) << fixed << r["test_time_std"].get<double>() << " ";
|
cout << setw(10) << right << setprecision(6) << fixed << r["test_time"].get<double>() << "±" << setw(6) << setprecision(4) << fixed << r["test_time_std"].get<double>() << " ";
|
||||||
cout << " " << r["hyperparameters"].get<string>();
|
cout << " " << r["hyperparameters"].get<string>();
|
||||||
|
@ -6,6 +6,8 @@ static platform::Registrar registrarTN("TANNew",
|
|||||||
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::TANNew();});
|
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::TANNew();});
|
||||||
static platform::Registrar registrarS("SPODE",
|
static platform::Registrar registrarS("SPODE",
|
||||||
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::SPODE(2);});
|
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::SPODE(2);});
|
||||||
|
static platform::Registrar registrarSN("SPODENew",
|
||||||
|
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::SPODENew(2);});
|
||||||
static platform::Registrar registrarK("KDB",
|
static platform::Registrar registrarK("KDB",
|
||||||
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::KDB(2);});
|
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::KDB(2);});
|
||||||
static platform::Registrar registrarKN("KDBNew",
|
static platform::Registrar registrarKN("KDBNew",
|
||||||
|
Loading…
Reference in New Issue
Block a user