Create BoostA2DE base class

This commit is contained in:
2024-05-15 11:53:17 +02:00
parent ef3c74633c
commit 1f236a70db
7 changed files with 356 additions and 12 deletions

View File

@@ -0,0 +1,89 @@
// ***************************************************************
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
// SPDX-FileType: SOURCE
// SPDX-License-Identifier: MIT
// ***************************************************************
#include <set>
#include <functional>
#include <limits.h>
#include <tuple>
#include <folding.hpp>
#include "bayesnet/feature_selection/CFS.h"
#include "bayesnet/feature_selection/FCBF.h"
#include "bayesnet/feature_selection/IWSS.h"
#include "BoostA2DE.h"
namespace bayesnet {
BoostA2DE::BoostA2DE(bool predict_voting) : Ensemble(predict_voting)
{
validHyperparameters = {
"maxModels", "bisection", "order", "convergence", "convergence_best", "threshold",
"select_features", "maxTolerance", "predict_voting", "block_update"
};
}
void BoostA2DE::buildModel(const torch::Tensor& weights)
{
models.clear();
}
void BoostA2DE::setHyperparameters(const nlohmann::json& hyperparameters_)
{
auto hyperparameters = hyperparameters_;
if (hyperparameters.contains("order")) {
std::vector<std::string> algos = { Orders.ASC, Orders.DESC, Orders.RAND };
order_algorithm = hyperparameters["order"];
if (std::find(algos.begin(), algos.end(), order_algorithm) == algos.end()) {
throw std::invalid_argument("Invalid order algorithm, valid values [" + Orders.ASC + ", " + Orders.DESC + ", " + Orders.RAND + "]");
}
hyperparameters.erase("order");
}
if (hyperparameters.contains("convergence")) {
convergence = hyperparameters["convergence"];
hyperparameters.erase("convergence");
}
if (hyperparameters.contains("convergence_best")) {
convergence_best = hyperparameters["convergence_best"];
hyperparameters.erase("convergence_best");
}
if (hyperparameters.contains("bisection")) {
bisection = hyperparameters["bisection"];
hyperparameters.erase("bisection");
}
if (hyperparameters.contains("threshold")) {
threshold = hyperparameters["threshold"];
hyperparameters.erase("threshold");
}
if (hyperparameters.contains("maxTolerance")) {
maxTolerance = hyperparameters["maxTolerance"];
if (maxTolerance < 1 || maxTolerance > 4)
throw std::invalid_argument("Invalid maxTolerance value, must be greater in [1, 4]");
hyperparameters.erase("maxTolerance");
}
if (hyperparameters.contains("predict_voting")) {
predict_voting = hyperparameters["predict_voting"];
hyperparameters.erase("predict_voting");
}
if (hyperparameters.contains("select_features")) {
auto selectedAlgorithm = hyperparameters["select_features"];
std::vector<std::string> algos = { SelectFeatures.IWSS, SelectFeatures.CFS, SelectFeatures.FCBF };
selectFeatures = true;
select_features_algorithm = selectedAlgorithm;
if (std::find(algos.begin(), algos.end(), selectedAlgorithm) == algos.end()) {
throw std::invalid_argument("Invalid selectFeatures value, valid values [" + SelectFeatures.IWSS + ", " + SelectFeatures.CFS + ", " + SelectFeatures.FCBF + "]");
}
hyperparameters.erase("select_features");
}
if (hyperparameters.contains("block_update")) {
block_update = hyperparameters["block_update"];
hyperparameters.erase("block_update");
}
Classifier::setHyperparameters(hyperparameters);
}
std::vector<std::string> BoostA2DE::graph(const std::string& title) const
{
return Ensemble::graph(title);
}
}

View File

@@ -0,0 +1,38 @@
// ***************************************************************
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
// SPDX-FileType: SOURCE
// SPDX-License-Identifier: MIT
// ***************************************************************
#ifndef BOOSTA2DE_H
#define BOOSTA2DE_H
#include <map>
#include "boost.h"
#include "bayesnet/classifiers/SPnDE.h"
#include "bayesnet/feature_selection/FeatureSelect.h"
#include "Ensemble.h"
namespace bayesnet {
class BoostA2DE : public Ensemble {
public:
explicit BoostA2DE(bool predict_voting = false);
virtual ~BoostA2DE() = default;
std::vector<std::string> graph(const std::string& title = "BoostA2DE") const override;
void setHyperparameters(const nlohmann::json& hyperparameters_) override;
protected:
void buildModel(const torch::Tensor& weights) override;
private:
torch::Tensor X_train, y_train, X_test, y_test;
// Hyperparameters
bool bisection = true; // if true, use bisection stratety to add k models at once to the ensemble
int maxTolerance = 3;
std::string order_algorithm; // order to process the KBest features asc, desc, rand
bool convergence = true; //if true, stop when the model does not improve
bool convergence_best = false; // wether to keep the best accuracy to the moment or the last accuracy as prior accuracy
bool selectFeatures = false; // if true, use feature selection
std::string select_features_algorithm = Orders.DESC; // Selected feature selection algorithm
FeatureSelect* featureSelector = nullptr;
double threshold = -1;
bool block_update = false;
};
}
#endif

View File

@@ -9,18 +9,9 @@
#include <map>
#include "bayesnet/classifiers/SPODE.h"
#include "bayesnet/feature_selection/FeatureSelect.h"
#include "boost.h"
#include "Ensemble.h"
namespace bayesnet {
const struct {
std::string CFS = "CFS";
std::string FCBF = "FCBF";
std::string IWSS = "IWSS";
}SelectFeatures;
const struct {
std::string ASC = "asc";
std::string DESC = "desc";
std::string RAND = "rand";
}Orders;
class BoostAODE : public Ensemble {
public:
explicit BoostAODE(bool predict_voting = false);

View File

@@ -0,0 +1,13 @@
#ifndef BOOST_H
#define BOOST_H
const struct {
std::string CFS = "CFS";
std::string FCBF = "FCBF";
std::string IWSS = "IWSS";
}SelectFeatures;
const struct {
std::string ASC = "asc";
std::string DESC = "desc";
std::string RAND = "rand";
}Orders;
#endif