diff --git a/.vscode/launch.json b/.vscode/launch.json index 4deb176..951d4c3 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -10,7 +10,7 @@ "-d", "iris", "-m", - "TAN", + "TANNew", "-p", "/Users/rmontanana/Code/discretizbench/datasets/", "--tensors" @@ -24,15 +24,12 @@ "program": "${workspaceFolder}/build/src/Platform/main", "args": [ "-m", - "TAN", + "TANNew", "-p", "/Users/rmontanana/Code/discretizbench/datasets", - "--discretize", "--stratified", - "--title", - "Debug test", "-d", - "ionosphere" + "iris" ], "cwd": "${workspaceFolder}/build/src/Platform", }, diff --git a/src/BayesNet/AODE.cc b/src/BayesNet/AODE.cc index f81b191..6ec9361 100644 --- a/src/BayesNet/AODE.cc +++ b/src/BayesNet/AODE.cc @@ -9,7 +9,7 @@ namespace bayesnet { models.push_back(std::make_unique(i)); } } - vector AODE::graph(string title) + vector AODE::graph(const string& title) { return Ensemble::graph(title); } diff --git a/src/BayesNet/AODE.h b/src/BayesNet/AODE.h index bc859e7..9a698e1 100644 --- a/src/BayesNet/AODE.h +++ b/src/BayesNet/AODE.h @@ -9,7 +9,7 @@ namespace bayesnet { public: AODE(); virtual ~AODE() {}; - vector graph(string title = "AODE") override; + vector graph(const string& title = "AODE") override; }; } #endif \ No newline at end of file diff --git a/src/BayesNet/BaseClassifier.h b/src/BayesNet/BaseClassifier.h index 16daaa6..6ade380 100644 --- a/src/BayesNet/BaseClassifier.h +++ b/src/BayesNet/BaseClassifier.h @@ -16,7 +16,7 @@ namespace bayesnet { int virtual getNumberOfEdges() = 0; int virtual getNumberOfStates() = 0; vector virtual show() = 0; - vector virtual graph(string title = "") = 0; + vector virtual graph(const string& title = "") = 0; virtual ~BaseClassifier() = default; const string inline getVersion() const { return "0.1.0"; }; }; diff --git a/src/BayesNet/CMakeLists.txt b/src/BayesNet/CMakeLists.txt index 6433d93..f502701 100644 --- a/src/BayesNet/CMakeLists.txt +++ b/src/BayesNet/CMakeLists.txt @@ -1,2 +1,3 @@ -add_library(BayesNet bayesnetUtils.cc Network.cc Node.cc BayesMetrics.cc Classifier.cc KDB.cc TAN.cc SPODE.cc Ensemble.cc AODE.cc Mst.cc) -target_link_libraries(BayesNet "${TORCH_LIBRARIES}") \ No newline at end of file +include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp) +add_library(BayesNet bayesnetUtils.cc Network.cc Node.cc BayesMetrics.cc Classifier.cc KDB.cc TAN.cc SPODE.cc Ensemble.cc AODE.cc TANNew.cc Mst.cc) +target_link_libraries(BayesNet mdlp "${TORCH_LIBRARIES}") \ No newline at end of file diff --git a/src/BayesNet/Ensemble.cc b/src/BayesNet/Ensemble.cc index c18b51e..ca64798 100644 --- a/src/BayesNet/Ensemble.cc +++ b/src/BayesNet/Ensemble.cc @@ -139,7 +139,7 @@ namespace bayesnet { } return result; } - vector Ensemble::graph(string title) + vector Ensemble::graph(const string& title) { auto result = vector(); for (auto i = 0; i < n_models; ++i) { diff --git a/src/BayesNet/Ensemble.h b/src/BayesNet/Ensemble.h index 2f85092..4f5c7c6 100644 --- a/src/BayesNet/Ensemble.h +++ b/src/BayesNet/Ensemble.h @@ -40,7 +40,7 @@ namespace bayesnet { int getNumberOfEdges() override; int getNumberOfStates() override; vector show() override; - vector graph(string title) override; + vector graph(const string& title) override; }; } #endif diff --git a/src/BayesNet/KDB.cc b/src/BayesNet/KDB.cc index b041dac..cfdf750 100644 --- a/src/BayesNet/KDB.cc +++ b/src/BayesNet/KDB.cc @@ -79,11 +79,12 @@ namespace bayesnet { exit_cond = num == n_edges || candidates.size(0) == 0; } } - vector KDB::graph(string title) + vector KDB::graph(const string& title) { + string header{ title }; if (title == "KDB") { - title += " (k=" + to_string(k) + ", theta=" + to_string(theta) + ")"; + header += " (k=" + to_string(k) + ", theta=" + to_string(theta) + ")"; } - return model.graph(title); + return model.graph(header); } } \ No newline at end of file diff --git a/src/BayesNet/KDB.h b/src/BayesNet/KDB.h index e3f257f..11a69d7 100644 --- a/src/BayesNet/KDB.h +++ b/src/BayesNet/KDB.h @@ -15,7 +15,7 @@ namespace bayesnet { public: explicit KDB(int k, float theta = 0.03); virtual ~KDB() {}; - vector graph(string name = "KDB") override; + vector graph(const string& name = "KDB") override; }; } #endif \ No newline at end of file diff --git a/src/BayesNet/SPODE.cc b/src/BayesNet/SPODE.cc index 68ff0b9..a627cca 100644 --- a/src/BayesNet/SPODE.cc +++ b/src/BayesNet/SPODE.cc @@ -17,7 +17,7 @@ namespace bayesnet { } } } - vector SPODE::graph(string name ) + vector SPODE::graph(const string& name) { return model.graph(name); } diff --git a/src/BayesNet/SPODE.h b/src/BayesNet/SPODE.h index 30f0b46..4625714 100644 --- a/src/BayesNet/SPODE.h +++ b/src/BayesNet/SPODE.h @@ -11,7 +11,7 @@ namespace bayesnet { public: explicit SPODE(int root); virtual ~SPODE() {}; - vector graph(string name = "SPODE") override; + vector graph(const string& name = "SPODE") override; }; } #endif \ No newline at end of file diff --git a/src/BayesNet/TAN.cc b/src/BayesNet/TAN.cc index 51f0c1b..df6561a 100644 --- a/src/BayesNet/TAN.cc +++ b/src/BayesNet/TAN.cc @@ -34,7 +34,7 @@ namespace bayesnet { model.addEdge(className, feature); } } - vector TAN::graph(string title) + vector TAN::graph(const string& title) { return model.graph(title); } diff --git a/src/BayesNet/TAN.h b/src/BayesNet/TAN.h index ce9b10a..5c7cf49 100644 --- a/src/BayesNet/TAN.h +++ b/src/BayesNet/TAN.h @@ -11,7 +11,7 @@ namespace bayesnet { public: TAN(); virtual ~TAN() {}; - vector graph(string name = "TAN") override; + vector graph(const string& name = "TAN") override; }; } #endif \ No newline at end of file diff --git a/src/BayesNet/TANNew.cc b/src/BayesNet/TANNew.cc new file mode 100644 index 0000000..6e7ff4e --- /dev/null +++ b/src/BayesNet/TANNew.cc @@ -0,0 +1,23 @@ +#include "TANNew.h" + +namespace bayesnet { + using namespace std; + TANNew::TANNew() : TAN(), discretizer{ mdlp::CPPFImdlp() } {} + TANNew::~TANNew() {} + TANNew& TANNew::fit(torch::Tensor& X, torch::Tensor& y, vector& features, string className, map>& states) + { + /* + Hay que discretizar los datos de entrada y luego en predict discretizar tambiƩn con el mmismo modelo, hacer un transform solamente. + */ + TAN::fit(X, y, features, className, states); + return *this; + } + void TANNew::train() + { + TAN::train(); + } + vector TANNew::graph(const string& name) + { + return TAN::graph(name); + } +} \ No newline at end of file diff --git a/src/BayesNet/TANNew.h b/src/BayesNet/TANNew.h new file mode 100644 index 0000000..b0bddb8 --- /dev/null +++ b/src/BayesNet/TANNew.h @@ -0,0 +1,21 @@ +#ifndef TANNEW_H +#define TANNEW_H +#include "TAN.h" +#include "CPPFImdlp.h" + +namespace bayesnet { + using namespace std; + class TANNew : public TAN { + private: + mdlp::CPPFImdlp discretizer; + public: + TANNew(); + virtual ~TANNew(); + void train() override; + TANNew& fit(torch::Tensor& X, torch::Tensor& y, vector& features, string className, map>& states) override; + vector graph(const string& name = "TAN") override; + static inline string version() { return "0.0.1"; }; + }; +} + +#endif // !TANNEW_H \ No newline at end of file diff --git a/src/Platform/Datasets.cc b/src/Platform/Datasets.cc index 11b83ac..6756148 100644 --- a/src/Platform/Datasets.cc +++ b/src/Platform/Datasets.cc @@ -207,9 +207,9 @@ namespace platform { if (discretize) { Xd = discretizeDataset(Xv, yv); computeStates(); - n_samples = Xd[0].size(); - n_features = Xd.size(); } + n_samples = Xv[0].size(); + n_features = Xv.size(); loaded = true; } void Dataset::buildTensors() diff --git a/src/Platform/Experiment.cc b/src/Platform/Experiment.cc index 97e7289..58f23cc 100644 --- a/src/Platform/Experiment.cc +++ b/src/Platform/Experiment.cc @@ -104,7 +104,7 @@ namespace platform { void Experiment::cross_validation(const string& path, const string& fileName) { - auto datasets = platform::Datasets(path, true, platform::ARFF); + auto datasets = platform::Datasets(path, discretized, platform::ARFF); // Get dataset auto [X, y] = datasets.getTensors(fileName); auto states = datasets.getStates(fileName); diff --git a/src/Platform/Models.h b/src/Platform/Models.h index 0bb8d51..ae675ea 100644 --- a/src/Platform/Models.h +++ b/src/Platform/Models.h @@ -6,6 +6,7 @@ #include "TAN.h" #include "KDB.h" #include "SPODE.h" +#include "TANNew.h" namespace platform { class Models { private: diff --git a/src/Platform/modelRegister.h b/src/Platform/modelRegister.h index a4188bc..e6788cb 100644 --- a/src/Platform/modelRegister.h +++ b/src/Platform/modelRegister.h @@ -2,6 +2,8 @@ #define MODEL_REGISTER_H static platform::Registrar registrarT("TAN", [](void) -> bayesnet::BaseClassifier* { return new bayesnet::TAN();}); +static platform::Registrar registrarTN("TANNew", + [](void) -> bayesnet::BaseClassifier* { return new bayesnet::TANNew();}); static platform::Registrar registrarS("SPODE", [](void) -> bayesnet::BaseClassifier* { return new bayesnet::SPODE(2);}); static platform::Registrar registrarK("KDB",