Begin implementation
This commit is contained in:
parent
adf650d257
commit
a18fbe5594
9
.vscode/launch.json
vendored
9
.vscode/launch.json
vendored
@ -10,7 +10,7 @@
|
|||||||
"-d",
|
"-d",
|
||||||
"iris",
|
"iris",
|
||||||
"-m",
|
"-m",
|
||||||
"TAN",
|
"TANNew",
|
||||||
"-p",
|
"-p",
|
||||||
"/Users/rmontanana/Code/discretizbench/datasets/",
|
"/Users/rmontanana/Code/discretizbench/datasets/",
|
||||||
"--tensors"
|
"--tensors"
|
||||||
@ -24,15 +24,12 @@
|
|||||||
"program": "${workspaceFolder}/build/src/Platform/main",
|
"program": "${workspaceFolder}/build/src/Platform/main",
|
||||||
"args": [
|
"args": [
|
||||||
"-m",
|
"-m",
|
||||||
"TAN",
|
"TANNew",
|
||||||
"-p",
|
"-p",
|
||||||
"/Users/rmontanana/Code/discretizbench/datasets",
|
"/Users/rmontanana/Code/discretizbench/datasets",
|
||||||
"--discretize",
|
|
||||||
"--stratified",
|
"--stratified",
|
||||||
"--title",
|
|
||||||
"Debug test",
|
|
||||||
"-d",
|
"-d",
|
||||||
"ionosphere"
|
"iris"
|
||||||
],
|
],
|
||||||
"cwd": "${workspaceFolder}/build/src/Platform",
|
"cwd": "${workspaceFolder}/build/src/Platform",
|
||||||
},
|
},
|
||||||
|
@ -9,7 +9,7 @@ namespace bayesnet {
|
|||||||
models.push_back(std::make_unique<SPODE>(i));
|
models.push_back(std::make_unique<SPODE>(i));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
vector<string> AODE::graph(string title)
|
vector<string> AODE::graph(const string& title)
|
||||||
{
|
{
|
||||||
return Ensemble::graph(title);
|
return Ensemble::graph(title);
|
||||||
}
|
}
|
||||||
|
@ -9,7 +9,7 @@ namespace bayesnet {
|
|||||||
public:
|
public:
|
||||||
AODE();
|
AODE();
|
||||||
virtual ~AODE() {};
|
virtual ~AODE() {};
|
||||||
vector<string> graph(string title = "AODE") override;
|
vector<string> graph(const string& title = "AODE") override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
@ -16,7 +16,7 @@ namespace bayesnet {
|
|||||||
int virtual getNumberOfEdges() = 0;
|
int virtual getNumberOfEdges() = 0;
|
||||||
int virtual getNumberOfStates() = 0;
|
int virtual getNumberOfStates() = 0;
|
||||||
vector<string> virtual show() = 0;
|
vector<string> virtual show() = 0;
|
||||||
vector<string> virtual graph(string title = "") = 0;
|
vector<string> virtual graph(const string& title = "") = 0;
|
||||||
virtual ~BaseClassifier() = default;
|
virtual ~BaseClassifier() = default;
|
||||||
const string inline getVersion() const { return "0.1.0"; };
|
const string inline getVersion() const { return "0.1.0"; };
|
||||||
};
|
};
|
||||||
|
@ -1,2 +1,3 @@
|
|||||||
add_library(BayesNet bayesnetUtils.cc Network.cc Node.cc BayesMetrics.cc Classifier.cc KDB.cc TAN.cc SPODE.cc Ensemble.cc AODE.cc Mst.cc)
|
include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp)
|
||||||
target_link_libraries(BayesNet "${TORCH_LIBRARIES}")
|
add_library(BayesNet bayesnetUtils.cc Network.cc Node.cc BayesMetrics.cc Classifier.cc KDB.cc TAN.cc SPODE.cc Ensemble.cc AODE.cc TANNew.cc Mst.cc)
|
||||||
|
target_link_libraries(BayesNet mdlp "${TORCH_LIBRARIES}")
|
@ -139,7 +139,7 @@ namespace bayesnet {
|
|||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
vector<string> Ensemble::graph(string title)
|
vector<string> Ensemble::graph(const string& title)
|
||||||
{
|
{
|
||||||
auto result = vector<string>();
|
auto result = vector<string>();
|
||||||
for (auto i = 0; i < n_models; ++i) {
|
for (auto i = 0; i < n_models; ++i) {
|
||||||
|
@ -40,7 +40,7 @@ namespace bayesnet {
|
|||||||
int getNumberOfEdges() override;
|
int getNumberOfEdges() override;
|
||||||
int getNumberOfStates() override;
|
int getNumberOfStates() override;
|
||||||
vector<string> show() override;
|
vector<string> show() override;
|
||||||
vector<string> graph(string title) override;
|
vector<string> graph(const string& title) override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
@ -79,11 +79,12 @@ namespace bayesnet {
|
|||||||
exit_cond = num == n_edges || candidates.size(0) == 0;
|
exit_cond = num == n_edges || candidates.size(0) == 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
vector<string> KDB::graph(string title)
|
vector<string> KDB::graph(const string& title)
|
||||||
{
|
{
|
||||||
|
string header{ title };
|
||||||
if (title == "KDB") {
|
if (title == "KDB") {
|
||||||
title += " (k=" + to_string(k) + ", theta=" + to_string(theta) + ")";
|
header += " (k=" + to_string(k) + ", theta=" + to_string(theta) + ")";
|
||||||
}
|
}
|
||||||
return model.graph(title);
|
return model.graph(header);
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -15,7 +15,7 @@ namespace bayesnet {
|
|||||||
public:
|
public:
|
||||||
explicit KDB(int k, float theta = 0.03);
|
explicit KDB(int k, float theta = 0.03);
|
||||||
virtual ~KDB() {};
|
virtual ~KDB() {};
|
||||||
vector<string> graph(string name = "KDB") override;
|
vector<string> graph(const string& name = "KDB") override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
@ -17,7 +17,7 @@ namespace bayesnet {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
vector<string> SPODE::graph(string name )
|
vector<string> SPODE::graph(const string& name)
|
||||||
{
|
{
|
||||||
return model.graph(name);
|
return model.graph(name);
|
||||||
}
|
}
|
||||||
|
@ -11,7 +11,7 @@ namespace bayesnet {
|
|||||||
public:
|
public:
|
||||||
explicit SPODE(int root);
|
explicit SPODE(int root);
|
||||||
virtual ~SPODE() {};
|
virtual ~SPODE() {};
|
||||||
vector<string> graph(string name = "SPODE") override;
|
vector<string> graph(const string& name = "SPODE") override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
@ -34,7 +34,7 @@ namespace bayesnet {
|
|||||||
model.addEdge(className, feature);
|
model.addEdge(className, feature);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
vector<string> TAN::graph(string title)
|
vector<string> TAN::graph(const string& title)
|
||||||
{
|
{
|
||||||
return model.graph(title);
|
return model.graph(title);
|
||||||
}
|
}
|
||||||
|
@ -11,7 +11,7 @@ namespace bayesnet {
|
|||||||
public:
|
public:
|
||||||
TAN();
|
TAN();
|
||||||
virtual ~TAN() {};
|
virtual ~TAN() {};
|
||||||
vector<string> graph(string name = "TAN") override;
|
vector<string> graph(const string& name = "TAN") override;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
23
src/BayesNet/TANNew.cc
Normal file
23
src/BayesNet/TANNew.cc
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
#include "TANNew.h"
|
||||||
|
|
||||||
|
namespace bayesnet {
|
||||||
|
using namespace std;
|
||||||
|
TANNew::TANNew() : TAN(), discretizer{ mdlp::CPPFImdlp() } {}
|
||||||
|
TANNew::~TANNew() {}
|
||||||
|
TANNew& TANNew::fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states)
|
||||||
|
{
|
||||||
|
/*
|
||||||
|
Hay que discretizar los datos de entrada y luego en predict discretizar también con el mmismo modelo, hacer un transform solamente.
|
||||||
|
*/
|
||||||
|
TAN::fit(X, y, features, className, states);
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
void TANNew::train()
|
||||||
|
{
|
||||||
|
TAN::train();
|
||||||
|
}
|
||||||
|
vector<string> TANNew::graph(const string& name)
|
||||||
|
{
|
||||||
|
return TAN::graph(name);
|
||||||
|
}
|
||||||
|
}
|
21
src/BayesNet/TANNew.h
Normal file
21
src/BayesNet/TANNew.h
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
#ifndef TANNEW_H
|
||||||
|
#define TANNEW_H
|
||||||
|
#include "TAN.h"
|
||||||
|
#include "CPPFImdlp.h"
|
||||||
|
|
||||||
|
namespace bayesnet {
|
||||||
|
using namespace std;
|
||||||
|
class TANNew : public TAN {
|
||||||
|
private:
|
||||||
|
mdlp::CPPFImdlp discretizer;
|
||||||
|
public:
|
||||||
|
TANNew();
|
||||||
|
virtual ~TANNew();
|
||||||
|
void train() override;
|
||||||
|
TANNew& fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states) override;
|
||||||
|
vector<string> graph(const string& name = "TAN") override;
|
||||||
|
static inline string version() { return "0.0.1"; };
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // !TANNEW_H
|
@ -207,9 +207,9 @@ namespace platform {
|
|||||||
if (discretize) {
|
if (discretize) {
|
||||||
Xd = discretizeDataset(Xv, yv);
|
Xd = discretizeDataset(Xv, yv);
|
||||||
computeStates();
|
computeStates();
|
||||||
n_samples = Xd[0].size();
|
|
||||||
n_features = Xd.size();
|
|
||||||
}
|
}
|
||||||
|
n_samples = Xv[0].size();
|
||||||
|
n_features = Xv.size();
|
||||||
loaded = true;
|
loaded = true;
|
||||||
}
|
}
|
||||||
void Dataset::buildTensors()
|
void Dataset::buildTensors()
|
||||||
|
@ -104,7 +104,7 @@ namespace platform {
|
|||||||
|
|
||||||
void Experiment::cross_validation(const string& path, const string& fileName)
|
void Experiment::cross_validation(const string& path, const string& fileName)
|
||||||
{
|
{
|
||||||
auto datasets = platform::Datasets(path, true, platform::ARFF);
|
auto datasets = platform::Datasets(path, discretized, platform::ARFF);
|
||||||
// Get dataset
|
// Get dataset
|
||||||
auto [X, y] = datasets.getTensors(fileName);
|
auto [X, y] = datasets.getTensors(fileName);
|
||||||
auto states = datasets.getStates(fileName);
|
auto states = datasets.getStates(fileName);
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
#include "TAN.h"
|
#include "TAN.h"
|
||||||
#include "KDB.h"
|
#include "KDB.h"
|
||||||
#include "SPODE.h"
|
#include "SPODE.h"
|
||||||
|
#include "TANNew.h"
|
||||||
namespace platform {
|
namespace platform {
|
||||||
class Models {
|
class Models {
|
||||||
private:
|
private:
|
||||||
|
@ -2,6 +2,8 @@
|
|||||||
#define MODEL_REGISTER_H
|
#define MODEL_REGISTER_H
|
||||||
static platform::Registrar registrarT("TAN",
|
static platform::Registrar registrarT("TAN",
|
||||||
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::TAN();});
|
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::TAN();});
|
||||||
|
static platform::Registrar registrarTN("TANNew",
|
||||||
|
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::TANNew();});
|
||||||
static platform::Registrar registrarS("SPODE",
|
static platform::Registrar registrarS("SPODE",
|
||||||
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::SPODE(2);});
|
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::SPODE(2);});
|
||||||
static platform::Registrar registrarK("KDB",
|
static platform::Registrar registrarK("KDB",
|
||||||
|
Loading…
Reference in New Issue
Block a user