Add BoostAODE model based on AODE

This commit is contained in:
Ricardo Montañana Gómez 2023-08-15 16:16:04 +02:00
parent fa612c531e
commit 4d4780c1d5
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
5 changed files with 35 additions and 1 deletions

16
src/BayesNet/BoostAODE.cc Normal file
View File

@ -0,0 +1,16 @@
#include "BoostAODE.h"
namespace bayesnet {
BoostAODE::BoostAODE() : Ensemble() {}
void BoostAODE::buildModel(const torch::Tensor& weights)
{
models.clear();
for (int i = 0; i < features.size(); ++i) {
models.push_back(std::make_unique<SPODE>(i));
}
}
vector<string> BoostAODE::graph(const string& title) const
{
return Ensemble::graph(title);
}
}

15
src/BayesNet/BoostAODE.h Normal file
View File

@ -0,0 +1,15 @@
#ifndef BOOSTAODE_H
#define BOOSTAODE_H
#include "Ensemble.h"
#include "SPODE.h"
namespace bayesnet {
class BoostAODE : public Ensemble {
protected:
void buildModel(const torch::Tensor& weights) override;
public:
BoostAODE();
virtual ~BoostAODE() {};
vector<string> graph(const string& title = "BoostAODE") const override;
};
}
#endif

View File

@ -3,5 +3,5 @@ include_directories(${BayesNet_SOURCE_DIR}/lib/Files)
include_directories(${BayesNet_SOURCE_DIR}/src/BayesNet)
include_directories(${BayesNet_SOURCE_DIR}/src/Platform)
add_library(BayesNet bayesnetUtils.cc Network.cc Node.cc BayesMetrics.cc Classifier.cc
KDB.cc TAN.cc SPODE.cc Ensemble.cc AODE.cc TANLd.cc KDBLd.cc SPODELd.cc AODELd.cc Mst.cc Proposal.cc ${BayesNet_SOURCE_DIR}/src/Platform/Models.cc)
KDB.cc TAN.cc SPODE.cc Ensemble.cc AODE.cc TANLd.cc KDBLd.cc SPODELd.cc AODELd.cc BoostAODE.cc Mst.cc Proposal.cc ${BayesNet_SOURCE_DIR}/src/Platform/Models.cc)
target_link_libraries(BayesNet mdlp ArffFiles "${TORCH_LIBRARIES}")

View File

@ -10,6 +10,7 @@
#include "KDBLd.h"
#include "SPODELd.h"
#include "AODELd.h"
#include "BoostAODE.h"
namespace platform {
class Models {
private:

View File

@ -16,4 +16,6 @@ static platform::Registrar registrarA("AODE",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::AODE();});
static platform::Registrar registrarALD("AODELd",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::AODELd();});
static platform::Registrar registrarBA("BoostAODE",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::BoostAODE();});
#endif