Create BoostA2DE base class
This commit is contained in:
89
bayesnet/ensembles/BoostA2DE.cc
Normal file
89
bayesnet/ensembles/BoostA2DE.cc
Normal file
@@ -0,0 +1,89 @@
|
||||
// ***************************************************************
|
||||
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
|
||||
// SPDX-FileType: SOURCE
|
||||
// SPDX-License-Identifier: MIT
|
||||
// ***************************************************************
|
||||
|
||||
#include <set>
|
||||
#include <functional>
|
||||
#include <limits.h>
|
||||
#include <tuple>
|
||||
#include <folding.hpp>
|
||||
#include "bayesnet/feature_selection/CFS.h"
|
||||
#include "bayesnet/feature_selection/FCBF.h"
|
||||
#include "bayesnet/feature_selection/IWSS.h"
|
||||
#include "BoostA2DE.h"
|
||||
|
||||
namespace bayesnet {
|
||||
|
||||
BoostA2DE::BoostA2DE(bool predict_voting) : Ensemble(predict_voting)
|
||||
{
|
||||
validHyperparameters = {
|
||||
"maxModels", "bisection", "order", "convergence", "convergence_best", "threshold",
|
||||
"select_features", "maxTolerance", "predict_voting", "block_update"
|
||||
};
|
||||
|
||||
}
|
||||
void BoostA2DE::buildModel(const torch::Tensor& weights)
|
||||
{
|
||||
models.clear();
|
||||
|
||||
}
|
||||
void BoostA2DE::setHyperparameters(const nlohmann::json& hyperparameters_)
|
||||
{
|
||||
auto hyperparameters = hyperparameters_;
|
||||
if (hyperparameters.contains("order")) {
|
||||
std::vector<std::string> algos = { Orders.ASC, Orders.DESC, Orders.RAND };
|
||||
order_algorithm = hyperparameters["order"];
|
||||
if (std::find(algos.begin(), algos.end(), order_algorithm) == algos.end()) {
|
||||
throw std::invalid_argument("Invalid order algorithm, valid values [" + Orders.ASC + ", " + Orders.DESC + ", " + Orders.RAND + "]");
|
||||
}
|
||||
hyperparameters.erase("order");
|
||||
}
|
||||
if (hyperparameters.contains("convergence")) {
|
||||
convergence = hyperparameters["convergence"];
|
||||
hyperparameters.erase("convergence");
|
||||
}
|
||||
if (hyperparameters.contains("convergence_best")) {
|
||||
convergence_best = hyperparameters["convergence_best"];
|
||||
hyperparameters.erase("convergence_best");
|
||||
}
|
||||
if (hyperparameters.contains("bisection")) {
|
||||
bisection = hyperparameters["bisection"];
|
||||
hyperparameters.erase("bisection");
|
||||
}
|
||||
if (hyperparameters.contains("threshold")) {
|
||||
threshold = hyperparameters["threshold"];
|
||||
hyperparameters.erase("threshold");
|
||||
}
|
||||
if (hyperparameters.contains("maxTolerance")) {
|
||||
maxTolerance = hyperparameters["maxTolerance"];
|
||||
if (maxTolerance < 1 || maxTolerance > 4)
|
||||
throw std::invalid_argument("Invalid maxTolerance value, must be greater in [1, 4]");
|
||||
hyperparameters.erase("maxTolerance");
|
||||
}
|
||||
if (hyperparameters.contains("predict_voting")) {
|
||||
predict_voting = hyperparameters["predict_voting"];
|
||||
hyperparameters.erase("predict_voting");
|
||||
}
|
||||
if (hyperparameters.contains("select_features")) {
|
||||
auto selectedAlgorithm = hyperparameters["select_features"];
|
||||
std::vector<std::string> algos = { SelectFeatures.IWSS, SelectFeatures.CFS, SelectFeatures.FCBF };
|
||||
selectFeatures = true;
|
||||
select_features_algorithm = selectedAlgorithm;
|
||||
if (std::find(algos.begin(), algos.end(), selectedAlgorithm) == algos.end()) {
|
||||
throw std::invalid_argument("Invalid selectFeatures value, valid values [" + SelectFeatures.IWSS + ", " + SelectFeatures.CFS + ", " + SelectFeatures.FCBF + "]");
|
||||
}
|
||||
hyperparameters.erase("select_features");
|
||||
}
|
||||
if (hyperparameters.contains("block_update")) {
|
||||
block_update = hyperparameters["block_update"];
|
||||
hyperparameters.erase("block_update");
|
||||
}
|
||||
Classifier::setHyperparameters(hyperparameters);
|
||||
}
|
||||
std::vector<std::string> BoostA2DE::graph(const std::string& title) const
|
||||
{
|
||||
return Ensemble::graph(title);
|
||||
}
|
||||
}
|
38
bayesnet/ensembles/BoostA2DE.h
Normal file
38
bayesnet/ensembles/BoostA2DE.h
Normal file
@@ -0,0 +1,38 @@
|
||||
// ***************************************************************
|
||||
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
|
||||
// SPDX-FileType: SOURCE
|
||||
// SPDX-License-Identifier: MIT
|
||||
// ***************************************************************
|
||||
|
||||
#ifndef BOOSTA2DE_H
|
||||
#define BOOSTA2DE_H
|
||||
#include <map>
|
||||
#include "boost.h"
|
||||
#include "bayesnet/classifiers/SPnDE.h"
|
||||
#include "bayesnet/feature_selection/FeatureSelect.h"
|
||||
#include "Ensemble.h"
|
||||
namespace bayesnet {
|
||||
class BoostA2DE : public Ensemble {
|
||||
public:
|
||||
explicit BoostA2DE(bool predict_voting = false);
|
||||
virtual ~BoostA2DE() = default;
|
||||
std::vector<std::string> graph(const std::string& title = "BoostA2DE") const override;
|
||||
void setHyperparameters(const nlohmann::json& hyperparameters_) override;
|
||||
protected:
|
||||
void buildModel(const torch::Tensor& weights) override;
|
||||
private:
|
||||
torch::Tensor X_train, y_train, X_test, y_test;
|
||||
// Hyperparameters
|
||||
bool bisection = true; // if true, use bisection stratety to add k models at once to the ensemble
|
||||
int maxTolerance = 3;
|
||||
std::string order_algorithm; // order to process the KBest features asc, desc, rand
|
||||
bool convergence = true; //if true, stop when the model does not improve
|
||||
bool convergence_best = false; // wether to keep the best accuracy to the moment or the last accuracy as prior accuracy
|
||||
bool selectFeatures = false; // if true, use feature selection
|
||||
std::string select_features_algorithm = Orders.DESC; // Selected feature selection algorithm
|
||||
FeatureSelect* featureSelector = nullptr;
|
||||
double threshold = -1;
|
||||
bool block_update = false;
|
||||
};
|
||||
}
|
||||
#endif
|
@@ -9,18 +9,9 @@
|
||||
#include <map>
|
||||
#include "bayesnet/classifiers/SPODE.h"
|
||||
#include "bayesnet/feature_selection/FeatureSelect.h"
|
||||
#include "boost.h"
|
||||
#include "Ensemble.h"
|
||||
namespace bayesnet {
|
||||
const struct {
|
||||
std::string CFS = "CFS";
|
||||
std::string FCBF = "FCBF";
|
||||
std::string IWSS = "IWSS";
|
||||
}SelectFeatures;
|
||||
const struct {
|
||||
std::string ASC = "asc";
|
||||
std::string DESC = "desc";
|
||||
std::string RAND = "rand";
|
||||
}Orders;
|
||||
class BoostAODE : public Ensemble {
|
||||
public:
|
||||
explicit BoostAODE(bool predict_voting = false);
|
||||
|
13
bayesnet/ensembles/boost.h
Normal file
13
bayesnet/ensembles/boost.h
Normal file
@@ -0,0 +1,13 @@
|
||||
#ifndef BOOST_H
|
||||
#define BOOST_H
|
||||
const struct {
|
||||
std::string CFS = "CFS";
|
||||
std::string FCBF = "FCBF";
|
||||
std::string IWSS = "IWSS";
|
||||
}SelectFeatures;
|
||||
const struct {
|
||||
std::string ASC = "asc";
|
||||
std::string DESC = "desc";
|
||||
std::string RAND = "rand";
|
||||
}Orders;
|
||||
#endif
|
Reference in New Issue
Block a user