Begin implementation

This commit is contained in:
Ricardo Montañana Gómez 2023-07-31 19:53:55 +02:00
parent adf650d257
commit a18fbe5594
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
19 changed files with 70 additions and 24 deletions

9
.vscode/launch.json vendored
View File

@ -10,7 +10,7 @@
"-d",
"iris",
"-m",
"TAN",
"TANNew",
"-p",
"/Users/rmontanana/Code/discretizbench/datasets/",
"--tensors"
@ -24,15 +24,12 @@
"program": "${workspaceFolder}/build/src/Platform/main",
"args": [
"-m",
"TAN",
"TANNew",
"-p",
"/Users/rmontanana/Code/discretizbench/datasets",
"--discretize",
"--stratified",
"--title",
"Debug test",
"-d",
"ionosphere"
"iris"
],
"cwd": "${workspaceFolder}/build/src/Platform",
},

View File

@ -9,7 +9,7 @@ namespace bayesnet {
models.push_back(std::make_unique<SPODE>(i));
}
}
vector<string> AODE::graph(string title)
vector<string> AODE::graph(const string& title)
{
return Ensemble::graph(title);
}

View File

@ -9,7 +9,7 @@ namespace bayesnet {
public:
AODE();
virtual ~AODE() {};
vector<string> graph(string title = "AODE") override;
vector<string> graph(const string& title = "AODE") override;
};
}
#endif

View File

@ -16,7 +16,7 @@ namespace bayesnet {
int virtual getNumberOfEdges() = 0;
int virtual getNumberOfStates() = 0;
vector<string> virtual show() = 0;
vector<string> virtual graph(string title = "") = 0;
vector<string> virtual graph(const string& title = "") = 0;
virtual ~BaseClassifier() = default;
const string inline getVersion() const { return "0.1.0"; };
};

View File

@ -1,2 +1,3 @@
add_library(BayesNet bayesnetUtils.cc Network.cc Node.cc BayesMetrics.cc Classifier.cc KDB.cc TAN.cc SPODE.cc Ensemble.cc AODE.cc Mst.cc)
target_link_libraries(BayesNet "${TORCH_LIBRARIES}")
include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp)
add_library(BayesNet bayesnetUtils.cc Network.cc Node.cc BayesMetrics.cc Classifier.cc KDB.cc TAN.cc SPODE.cc Ensemble.cc AODE.cc TANNew.cc Mst.cc)
target_link_libraries(BayesNet mdlp "${TORCH_LIBRARIES}")

View File

@ -139,7 +139,7 @@ namespace bayesnet {
}
return result;
}
vector<string> Ensemble::graph(string title)
vector<string> Ensemble::graph(const string& title)
{
auto result = vector<string>();
for (auto i = 0; i < n_models; ++i) {

View File

@ -40,7 +40,7 @@ namespace bayesnet {
int getNumberOfEdges() override;
int getNumberOfStates() override;
vector<string> show() override;
vector<string> graph(string title) override;
vector<string> graph(const string& title) override;
};
}
#endif

View File

@ -79,11 +79,12 @@ namespace bayesnet {
exit_cond = num == n_edges || candidates.size(0) == 0;
}
}
vector<string> KDB::graph(string title)
vector<string> KDB::graph(const string& title)
{
string header{ title };
if (title == "KDB") {
title += " (k=" + to_string(k) + ", theta=" + to_string(theta) + ")";
header += " (k=" + to_string(k) + ", theta=" + to_string(theta) + ")";
}
return model.graph(title);
return model.graph(header);
}
}

View File

@ -15,7 +15,7 @@ namespace bayesnet {
public:
explicit KDB(int k, float theta = 0.03);
virtual ~KDB() {};
vector<string> graph(string name = "KDB") override;
vector<string> graph(const string& name = "KDB") override;
};
}
#endif

View File

@ -17,7 +17,7 @@ namespace bayesnet {
}
}
}
vector<string> SPODE::graph(string name )
vector<string> SPODE::graph(const string& name)
{
return model.graph(name);
}

View File

@ -11,7 +11,7 @@ namespace bayesnet {
public:
explicit SPODE(int root);
virtual ~SPODE() {};
vector<string> graph(string name = "SPODE") override;
vector<string> graph(const string& name = "SPODE") override;
};
}
#endif

View File

@ -34,7 +34,7 @@ namespace bayesnet {
model.addEdge(className, feature);
}
}
vector<string> TAN::graph(string title)
vector<string> TAN::graph(const string& title)
{
return model.graph(title);
}

View File

@ -11,7 +11,7 @@ namespace bayesnet {
public:
TAN();
virtual ~TAN() {};
vector<string> graph(string name = "TAN") override;
vector<string> graph(const string& name = "TAN") override;
};
}
#endif

23
src/BayesNet/TANNew.cc Normal file
View File

@ -0,0 +1,23 @@
#include "TANNew.h"
namespace bayesnet {
using namespace std;
TANNew::TANNew() : TAN(), discretizer{ mdlp::CPPFImdlp() } {}
TANNew::~TANNew() {}
TANNew& TANNew::fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states)
{
/*
Hay que discretizar los datos de entrada y luego en predict discretizar también con el mmismo modelo, hacer un transform solamente.
*/
TAN::fit(X, y, features, className, states);
return *this;
}
void TANNew::train()
{
TAN::train();
}
vector<string> TANNew::graph(const string& name)
{
return TAN::graph(name);
}
}

21
src/BayesNet/TANNew.h Normal file
View File

@ -0,0 +1,21 @@
#ifndef TANNEW_H
#define TANNEW_H
#include "TAN.h"
#include "CPPFImdlp.h"
namespace bayesnet {
using namespace std;
class TANNew : public TAN {
private:
mdlp::CPPFImdlp discretizer;
public:
TANNew();
virtual ~TANNew();
void train() override;
TANNew& fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states) override;
vector<string> graph(const string& name = "TAN") override;
static inline string version() { return "0.0.1"; };
};
}
#endif // !TANNEW_H

View File

@ -207,9 +207,9 @@ namespace platform {
if (discretize) {
Xd = discretizeDataset(Xv, yv);
computeStates();
n_samples = Xd[0].size();
n_features = Xd.size();
}
n_samples = Xv[0].size();
n_features = Xv.size();
loaded = true;
}
void Dataset::buildTensors()

View File

@ -104,7 +104,7 @@ namespace platform {
void Experiment::cross_validation(const string& path, const string& fileName)
{
auto datasets = platform::Datasets(path, true, platform::ARFF);
auto datasets = platform::Datasets(path, discretized, platform::ARFF);
// Get dataset
auto [X, y] = datasets.getTensors(fileName);
auto states = datasets.getStates(fileName);

View File

@ -6,6 +6,7 @@
#include "TAN.h"
#include "KDB.h"
#include "SPODE.h"
#include "TANNew.h"
namespace platform {
class Models {
private:

View File

@ -2,6 +2,8 @@
#define MODEL_REGISTER_H
static platform::Registrar registrarT("TAN",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::TAN();});
static platform::Registrar registrarTN("TANNew",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::TANNew();});
static platform::Registrar registrarS("SPODE",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::SPODE(2);});
static platform::Registrar registrarK("KDB",