From 4d4780c1d5797084f32c70aac3d735c658f10e9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Tue, 15 Aug 2023 16:16:04 +0200 Subject: [PATCH] Add BoostAODE model based on AODE --- src/BayesNet/BoostAODE.cc | 16 ++++++++++++++++ src/BayesNet/BoostAODE.h | 15 +++++++++++++++ src/BayesNet/CMakeLists.txt | 2 +- src/Platform/Models.h | 1 + src/Platform/modelRegister.h | 2 ++ 5 files changed, 35 insertions(+), 1 deletion(-) create mode 100644 src/BayesNet/BoostAODE.cc create mode 100644 src/BayesNet/BoostAODE.h diff --git a/src/BayesNet/BoostAODE.cc b/src/BayesNet/BoostAODE.cc new file mode 100644 index 0000000..baafa16 --- /dev/null +++ b/src/BayesNet/BoostAODE.cc @@ -0,0 +1,16 @@ +#include "BoostAODE.h" + +namespace bayesnet { + BoostAODE::BoostAODE() : Ensemble() {} + void BoostAODE::buildModel(const torch::Tensor& weights) + { + models.clear(); + for (int i = 0; i < features.size(); ++i) { + models.push_back(std::make_unique(i)); + } + } + vector BoostAODE::graph(const string& title) const + { + return Ensemble::graph(title); + } +} \ No newline at end of file diff --git a/src/BayesNet/BoostAODE.h b/src/BayesNet/BoostAODE.h new file mode 100644 index 0000000..66a871f --- /dev/null +++ b/src/BayesNet/BoostAODE.h @@ -0,0 +1,15 @@ +#ifndef BOOSTAODE_H +#define BOOSTAODE_H +#include "Ensemble.h" +#include "SPODE.h" +namespace bayesnet { + class BoostAODE : public Ensemble { + protected: + void buildModel(const torch::Tensor& weights) override; + public: + BoostAODE(); + virtual ~BoostAODE() {}; + vector graph(const string& title = "BoostAODE") const override; + }; +} +#endif \ No newline at end of file diff --git a/src/BayesNet/CMakeLists.txt b/src/BayesNet/CMakeLists.txt index a2b9126..a94d8e9 100644 --- a/src/BayesNet/CMakeLists.txt +++ b/src/BayesNet/CMakeLists.txt @@ -3,5 +3,5 @@ include_directories(${BayesNet_SOURCE_DIR}/lib/Files) include_directories(${BayesNet_SOURCE_DIR}/src/BayesNet) include_directories(${BayesNet_SOURCE_DIR}/src/Platform) add_library(BayesNet bayesnetUtils.cc Network.cc Node.cc BayesMetrics.cc Classifier.cc - KDB.cc TAN.cc SPODE.cc Ensemble.cc AODE.cc TANLd.cc KDBLd.cc SPODELd.cc AODELd.cc Mst.cc Proposal.cc ${BayesNet_SOURCE_DIR}/src/Platform/Models.cc) + KDB.cc TAN.cc SPODE.cc Ensemble.cc AODE.cc TANLd.cc KDBLd.cc SPODELd.cc AODELd.cc BoostAODE.cc Mst.cc Proposal.cc ${BayesNet_SOURCE_DIR}/src/Platform/Models.cc) target_link_libraries(BayesNet mdlp ArffFiles "${TORCH_LIBRARIES}") \ No newline at end of file diff --git a/src/Platform/Models.h b/src/Platform/Models.h index 0e3184b..6c5d437 100644 --- a/src/Platform/Models.h +++ b/src/Platform/Models.h @@ -10,6 +10,7 @@ #include "KDBLd.h" #include "SPODELd.h" #include "AODELd.h" +#include "BoostAODE.h" namespace platform { class Models { private: diff --git a/src/Platform/modelRegister.h b/src/Platform/modelRegister.h index 6ae9af3..04b48cf 100644 --- a/src/Platform/modelRegister.h +++ b/src/Platform/modelRegister.h @@ -16,4 +16,6 @@ static platform::Registrar registrarA("AODE", [](void) -> bayesnet::BaseClassifier* { return new bayesnet::AODE();}); static platform::Registrar registrarALD("AODELd", [](void) -> bayesnet::BaseClassifier* { return new bayesnet::AODELd();}); +static platform::Registrar registrarBA("BoostAODE", + [](void) -> bayesnet::BaseClassifier* { return new bayesnet::BoostAODE();}); #endif \ No newline at end of file