Begin implementation

This commit is contained in:
Ricardo Montañana Gómez 2023-07-31 19:53:55 +02:00
parent adf650d257
commit a18fbe5594
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
19 changed files with 70 additions and 24 deletions

9
.vscode/launch.json vendored
View File

@ -10,7 +10,7 @@
"-d", "-d",
"iris", "iris",
"-m", "-m",
"TAN", "TANNew",
"-p", "-p",
"/Users/rmontanana/Code/discretizbench/datasets/", "/Users/rmontanana/Code/discretizbench/datasets/",
"--tensors" "--tensors"
@ -24,15 +24,12 @@
"program": "${workspaceFolder}/build/src/Platform/main", "program": "${workspaceFolder}/build/src/Platform/main",
"args": [ "args": [
"-m", "-m",
"TAN", "TANNew",
"-p", "-p",
"/Users/rmontanana/Code/discretizbench/datasets", "/Users/rmontanana/Code/discretizbench/datasets",
"--discretize",
"--stratified", "--stratified",
"--title",
"Debug test",
"-d", "-d",
"ionosphere" "iris"
], ],
"cwd": "${workspaceFolder}/build/src/Platform", "cwd": "${workspaceFolder}/build/src/Platform",
}, },

View File

@ -9,7 +9,7 @@ namespace bayesnet {
models.push_back(std::make_unique<SPODE>(i)); models.push_back(std::make_unique<SPODE>(i));
} }
} }
vector<string> AODE::graph(string title) vector<string> AODE::graph(const string& title)
{ {
return Ensemble::graph(title); return Ensemble::graph(title);
} }

View File

@ -9,7 +9,7 @@ namespace bayesnet {
public: public:
AODE(); AODE();
virtual ~AODE() {}; virtual ~AODE() {};
vector<string> graph(string title = "AODE") override; vector<string> graph(const string& title = "AODE") override;
}; };
} }
#endif #endif

View File

@ -16,7 +16,7 @@ namespace bayesnet {
int virtual getNumberOfEdges() = 0; int virtual getNumberOfEdges() = 0;
int virtual getNumberOfStates() = 0; int virtual getNumberOfStates() = 0;
vector<string> virtual show() = 0; vector<string> virtual show() = 0;
vector<string> virtual graph(string title = "") = 0; vector<string> virtual graph(const string& title = "") = 0;
virtual ~BaseClassifier() = default; virtual ~BaseClassifier() = default;
const string inline getVersion() const { return "0.1.0"; }; const string inline getVersion() const { return "0.1.0"; };
}; };

View File

@ -1,2 +1,3 @@
add_library(BayesNet bayesnetUtils.cc Network.cc Node.cc BayesMetrics.cc Classifier.cc KDB.cc TAN.cc SPODE.cc Ensemble.cc AODE.cc Mst.cc) include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp)
target_link_libraries(BayesNet "${TORCH_LIBRARIES}") add_library(BayesNet bayesnetUtils.cc Network.cc Node.cc BayesMetrics.cc Classifier.cc KDB.cc TAN.cc SPODE.cc Ensemble.cc AODE.cc TANNew.cc Mst.cc)
target_link_libraries(BayesNet mdlp "${TORCH_LIBRARIES}")

View File

@ -139,7 +139,7 @@ namespace bayesnet {
} }
return result; return result;
} }
vector<string> Ensemble::graph(string title) vector<string> Ensemble::graph(const string& title)
{ {
auto result = vector<string>(); auto result = vector<string>();
for (auto i = 0; i < n_models; ++i) { for (auto i = 0; i < n_models; ++i) {

View File

@ -40,7 +40,7 @@ namespace bayesnet {
int getNumberOfEdges() override; int getNumberOfEdges() override;
int getNumberOfStates() override; int getNumberOfStates() override;
vector<string> show() override; vector<string> show() override;
vector<string> graph(string title) override; vector<string> graph(const string& title) override;
}; };
} }
#endif #endif

View File

@ -79,11 +79,12 @@ namespace bayesnet {
exit_cond = num == n_edges || candidates.size(0) == 0; exit_cond = num == n_edges || candidates.size(0) == 0;
} }
} }
vector<string> KDB::graph(string title) vector<string> KDB::graph(const string& title)
{ {
string header{ title };
if (title == "KDB") { if (title == "KDB") {
title += " (k=" + to_string(k) + ", theta=" + to_string(theta) + ")"; header += " (k=" + to_string(k) + ", theta=" + to_string(theta) + ")";
} }
return model.graph(title); return model.graph(header);
} }
} }

View File

@ -15,7 +15,7 @@ namespace bayesnet {
public: public:
explicit KDB(int k, float theta = 0.03); explicit KDB(int k, float theta = 0.03);
virtual ~KDB() {}; virtual ~KDB() {};
vector<string> graph(string name = "KDB") override; vector<string> graph(const string& name = "KDB") override;
}; };
} }
#endif #endif

View File

@ -17,7 +17,7 @@ namespace bayesnet {
} }
} }
} }
vector<string> SPODE::graph(string name ) vector<string> SPODE::graph(const string& name)
{ {
return model.graph(name); return model.graph(name);
} }

View File

@ -11,7 +11,7 @@ namespace bayesnet {
public: public:
explicit SPODE(int root); explicit SPODE(int root);
virtual ~SPODE() {}; virtual ~SPODE() {};
vector<string> graph(string name = "SPODE") override; vector<string> graph(const string& name = "SPODE") override;
}; };
} }
#endif #endif

View File

@ -34,7 +34,7 @@ namespace bayesnet {
model.addEdge(className, feature); model.addEdge(className, feature);
} }
} }
vector<string> TAN::graph(string title) vector<string> TAN::graph(const string& title)
{ {
return model.graph(title); return model.graph(title);
} }

View File

@ -11,7 +11,7 @@ namespace bayesnet {
public: public:
TAN(); TAN();
virtual ~TAN() {}; virtual ~TAN() {};
vector<string> graph(string name = "TAN") override; vector<string> graph(const string& name = "TAN") override;
}; };
} }
#endif #endif

23
src/BayesNet/TANNew.cc Normal file
View File

@ -0,0 +1,23 @@
#include "TANNew.h"
namespace bayesnet {
using namespace std;
TANNew::TANNew() : TAN(), discretizer{ mdlp::CPPFImdlp() } {}
TANNew::~TANNew() {}
TANNew& TANNew::fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states)
{
/*
Hay que discretizar los datos de entrada y luego en predict discretizar también con el mmismo modelo, hacer un transform solamente.
*/
TAN::fit(X, y, features, className, states);
return *this;
}
void TANNew::train()
{
TAN::train();
}
vector<string> TANNew::graph(const string& name)
{
return TAN::graph(name);
}
}

21
src/BayesNet/TANNew.h Normal file
View File

@ -0,0 +1,21 @@
#ifndef TANNEW_H
#define TANNEW_H
#include "TAN.h"
#include "CPPFImdlp.h"
namespace bayesnet {
using namespace std;
class TANNew : public TAN {
private:
mdlp::CPPFImdlp discretizer;
public:
TANNew();
virtual ~TANNew();
void train() override;
TANNew& fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states) override;
vector<string> graph(const string& name = "TAN") override;
static inline string version() { return "0.0.1"; };
};
}
#endif // !TANNEW_H

View File

@ -207,9 +207,9 @@ namespace platform {
if (discretize) { if (discretize) {
Xd = discretizeDataset(Xv, yv); Xd = discretizeDataset(Xv, yv);
computeStates(); computeStates();
n_samples = Xd[0].size();
n_features = Xd.size();
} }
n_samples = Xv[0].size();
n_features = Xv.size();
loaded = true; loaded = true;
} }
void Dataset::buildTensors() void Dataset::buildTensors()

View File

@ -104,7 +104,7 @@ namespace platform {
void Experiment::cross_validation(const string& path, const string& fileName) void Experiment::cross_validation(const string& path, const string& fileName)
{ {
auto datasets = platform::Datasets(path, true, platform::ARFF); auto datasets = platform::Datasets(path, discretized, platform::ARFF);
// Get dataset // Get dataset
auto [X, y] = datasets.getTensors(fileName); auto [X, y] = datasets.getTensors(fileName);
auto states = datasets.getStates(fileName); auto states = datasets.getStates(fileName);

View File

@ -6,6 +6,7 @@
#include "TAN.h" #include "TAN.h"
#include "KDB.h" #include "KDB.h"
#include "SPODE.h" #include "SPODE.h"
#include "TANNew.h"
namespace platform { namespace platform {
class Models { class Models {
private: private:

View File

@ -2,6 +2,8 @@
#define MODEL_REGISTER_H #define MODEL_REGISTER_H
static platform::Registrar registrarT("TAN", static platform::Registrar registrarT("TAN",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::TAN();}); [](void) -> bayesnet::BaseClassifier* { return new bayesnet::TAN();});
static platform::Registrar registrarTN("TANNew",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::TANNew();});
static platform::Registrar registrarS("SPODE", static platform::Registrar registrarS("SPODE",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::SPODE(2);}); [](void) -> bayesnet::BaseClassifier* { return new bayesnet::SPODE(2);});
static platform::Registrar registrarK("KDB", static platform::Registrar registrarK("KDB",