Begin to add AdaBoost implementation
This commit is contained in:
@@ -33,6 +33,7 @@ add_executable(b_grid commands/b_grid.cpp ${grid_sources}
|
||||
results/Result.cpp
|
||||
experimental_clfs/XA1DE.cpp
|
||||
experimental_clfs/ExpClf.cpp
|
||||
experimental_clfs/AdaBoost.cpp
|
||||
)
|
||||
target_link_libraries(b_grid ${MPI_CXX_LIBRARIES} "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy)
|
||||
|
||||
@@ -56,6 +57,7 @@ add_executable(b_main commands/b_main.cpp ${main_sources}
|
||||
results/Result.cpp
|
||||
experimental_clfs/XA1DE.cpp
|
||||
experimental_clfs/ExpClf.cpp
|
||||
experimental_clfs/ExpClf.cpp
|
||||
)
|
||||
target_link_libraries(b_main PRIVATE nlohmann_json::nlohmann_json "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy)
|
||||
|
||||
|
214
src/experimental_clfs/AdaBoost.cpp
Normal file
214
src/experimental_clfs/AdaBoost.cpp
Normal file
@@ -0,0 +1,214 @@
|
||||
// ***************************************************************
|
||||
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
|
||||
// SPDX-FileType: SOURCE
|
||||
// SPDX-License-Identifier: MIT
|
||||
// ***************************************************************
|
||||
|
||||
#include "AdaBoost.h"
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
#include <iomanip>
|
||||
|
||||
namespace platform {
|
||||
|
||||
AdaBoost::AdaBoost(int n_estimators)
|
||||
: Ensemble(true), n_estimators(n_estimators)
|
||||
{
|
||||
validHyperparameters = { "n_estimators" };
|
||||
}
|
||||
|
||||
void AdaBoost::buildModel(const torch::Tensor& weights)
|
||||
{
|
||||
// Initialize variables
|
||||
models.clear();
|
||||
alphas.clear();
|
||||
training_errors.clear();
|
||||
|
||||
// Initialize sample weights uniformly
|
||||
int n_samples = dataset.size(1);
|
||||
sample_weights = torch::ones({ n_samples }) / n_samples;
|
||||
|
||||
// If initial weights are provided, incorporate them
|
||||
if (weights.defined() && weights.numel() > 0) {
|
||||
sample_weights *= weights;
|
||||
normalizeWeights();
|
||||
}
|
||||
|
||||
// Main AdaBoost training loop (SAMME algorithm)
|
||||
for (int iter = 0; iter < n_estimators; ++iter) {
|
||||
// Train base estimator with current sample weights
|
||||
auto estimator = trainBaseEstimator(sample_weights);
|
||||
|
||||
// Calculate weighted error
|
||||
double weighted_error = calculateWeightedError(estimator.get(), sample_weights);
|
||||
training_errors.push_back(weighted_error);
|
||||
|
||||
// Check if error is too high (worse than random guessing)
|
||||
double random_guess_error = 1.0 - (1.0 / getClassNumStates());
|
||||
if (weighted_error >= random_guess_error) {
|
||||
// If only one estimator and it's worse than random, keep it with zero weight
|
||||
if (models.empty()) {
|
||||
models.push_back(std::move(estimator));
|
||||
alphas.push_back(0.0);
|
||||
}
|
||||
break; // Stop boosting
|
||||
}
|
||||
|
||||
// Calculate alpha (estimator weight) using SAMME formula
|
||||
// alpha = log((1 - err) / err) + log(K - 1)
|
||||
double alpha = std::log((1.0 - weighted_error) / weighted_error) +
|
||||
std::log(static_cast<double>(getClassNumStates() - 1));
|
||||
|
||||
// Store the estimator and its weight
|
||||
models.push_back(std::move(estimator));
|
||||
alphas.push_back(alpha);
|
||||
|
||||
// Update sample weights
|
||||
updateSampleWeights(models.back().get(), alpha);
|
||||
|
||||
// Normalize weights
|
||||
normalizeWeights();
|
||||
|
||||
// Check for perfect classification
|
||||
if (weighted_error < 1e-10) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Set the number of models actually trained
|
||||
n_models = models.size();
|
||||
}
|
||||
|
||||
void AdaBoost::trainModel(const torch::Tensor& weights, const Smoothing_t smoothing)
|
||||
{
|
||||
// AdaBoost handles its own weight management, so we just build the model
|
||||
buildModel(weights);
|
||||
}
|
||||
|
||||
std::unique_ptr<Classifier> AdaBoost::trainBaseEstimator(const torch::Tensor& weights)
|
||||
{
|
||||
// Create a new classifier instance
|
||||
// You need to implement this based on your specific base classifier
|
||||
// For example, if using Decision Trees:
|
||||
// auto classifier = std::make_unique<DecisionTree>();
|
||||
|
||||
// Or if using a factory method:
|
||||
// auto classifier = ClassifierFactory::create("DecisionTree");
|
||||
|
||||
// Placeholder - replace with actual classifier creation
|
||||
throw std::runtime_error("AdaBoost::trainBaseEstimator - You need to implement base classifier creation");
|
||||
|
||||
// Once you have the classifier creation implemented, uncomment:
|
||||
// classifier->fit(dataset, features, className, states, weights, Smoothing_t::NONE);
|
||||
// return classifier;
|
||||
}
|
||||
|
||||
double AdaBoost::calculateWeightedError(Classifier* estimator, const torch::Tensor& weights)
|
||||
{
|
||||
// Get predictions from the estimator
|
||||
auto X = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), torch::indexing::Slice() });
|
||||
auto y_true = dataset.index({ -1, torch::indexing::Slice() });
|
||||
auto y_pred = estimator->predict(X.t());
|
||||
|
||||
// Calculate weighted error
|
||||
auto incorrect = (y_pred != y_true).to(torch::kFloat);
|
||||
double weighted_error = torch::sum(incorrect * weights).item<double>();
|
||||
|
||||
return weighted_error;
|
||||
}
|
||||
|
||||
void AdaBoost::updateSampleWeights(Classifier* estimator, double alpha)
|
||||
{
|
||||
// Get predictions from the estimator
|
||||
auto X = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), torch::indexing::Slice() });
|
||||
auto y_true = dataset.index({ -1, torch::indexing::Slice() });
|
||||
auto y_pred = estimator->predict(X.t());
|
||||
|
||||
// Update weights according to SAMME algorithm
|
||||
// w_i = w_i * exp(alpha * I(y_i != y_pred_i))
|
||||
auto incorrect = (y_pred != y_true).to(torch::kFloat);
|
||||
sample_weights *= torch::exp(alpha * incorrect);
|
||||
}
|
||||
|
||||
void AdaBoost::normalizeWeights()
|
||||
{
|
||||
// Normalize weights to sum to 1
|
||||
double sum_weights = torch::sum(sample_weights).item<double>();
|
||||
if (sum_weights > 0) {
|
||||
sample_weights /= sum_weights;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::string> AdaBoost::graph(const std::string& title) const
|
||||
{
|
||||
// Create a graph representation of the AdaBoost ensemble
|
||||
std::vector<std::string> graph_lines;
|
||||
|
||||
// Header
|
||||
graph_lines.push_back("digraph AdaBoost {");
|
||||
graph_lines.push_back(" rankdir=TB;");
|
||||
graph_lines.push_back(" node [shape=box];");
|
||||
|
||||
if (!title.empty()) {
|
||||
graph_lines.push_back(" label=\"" + title + "\";");
|
||||
graph_lines.push_back(" labelloc=t;");
|
||||
}
|
||||
|
||||
// Add input node
|
||||
graph_lines.push_back(" Input [shape=ellipse, label=\"Input Features\"];");
|
||||
|
||||
// Add base estimators
|
||||
for (size_t i = 0; i < models.size(); ++i) {
|
||||
std::stringstream ss;
|
||||
ss << " Estimator" << i << " [label=\"Base Estimator " << i + 1
|
||||
<< "\\nα = " << std::fixed << std::setprecision(3) << alphas[i] << "\"];";
|
||||
graph_lines.push_back(ss.str());
|
||||
|
||||
// Connect input to estimator
|
||||
ss.str("");
|
||||
ss << " Input -> Estimator" << i << ";";
|
||||
graph_lines.push_back(ss.str());
|
||||
}
|
||||
|
||||
// Add combination node
|
||||
graph_lines.push_back(" Combination [shape=diamond, label=\"Weighted Vote\"];");
|
||||
|
||||
// Connect estimators to combination
|
||||
for (size_t i = 0; i < models.size(); ++i) {
|
||||
std::stringstream ss;
|
||||
ss << " Estimator" << i << " -> Combination;";
|
||||
graph_lines.push_back(ss.str());
|
||||
}
|
||||
|
||||
// Add output node
|
||||
graph_lines.push_back(" Output [shape=ellipse, label=\"Final Prediction\"];");
|
||||
graph_lines.push_back(" Combination -> Output;");
|
||||
|
||||
// Close graph
|
||||
graph_lines.push_back("}");
|
||||
|
||||
return graph_lines;
|
||||
}
|
||||
|
||||
void AdaBoost::setHyperparameters(const nlohmann::json& hyperparameters)
|
||||
{
|
||||
// Set hyperparameters from JSON
|
||||
auto it = hyperparameters.find("n_estimators");
|
||||
if (it != hyperparameters.end()) {
|
||||
n_estimators = it->get<int>();
|
||||
if (n_estimators <= 0) {
|
||||
throw std::invalid_argument("n_estimators must be positive");
|
||||
}
|
||||
}
|
||||
|
||||
// Check for invalid hyperparameters
|
||||
for (auto& [key, value] : hyperparameters.items()) {
|
||||
if (std::find(validHyperparameters.begin(), validHyperparameters.end(), key) == validHyperparameters.end()) {
|
||||
throw std::invalid_argument("Invalid hyperparameter: " + key);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace bayesnet
|
58
src/experimental_clfs/AdaBoost.h
Normal file
58
src/experimental_clfs/AdaBoost.h
Normal file
@@ -0,0 +1,58 @@
|
||||
// ***************************************************************
|
||||
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
|
||||
// SPDX-FileType: SOURCE
|
||||
// SPDX-License-Identifier: MIT
|
||||
// ***************************************************************
|
||||
|
||||
#ifndef ADABOOST_H
|
||||
#define ADABOOST_H
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <torch/torch.h>
|
||||
#include <bayesnet/ensembles/Ensemble.h>
|
||||
|
||||
namespace platform {
|
||||
class AdaBoost : public bayesnet::Ensemble {
|
||||
public:
|
||||
explicit AdaBoost(int n_estimators = 100);
|
||||
virtual ~AdaBoost() = default;
|
||||
|
||||
// Override base class methods
|
||||
std::vector<std::string> graph(const std::string& title = "") const override;
|
||||
|
||||
// AdaBoost specific methods
|
||||
void setNEstimators(int n_estimators) { this->n_estimators = n_estimators; }
|
||||
int getNEstimators() const { return n_estimators; }
|
||||
|
||||
// Get the weight of each base estimator
|
||||
std::vector<double> getEstimatorWeights() const { return alphas; }
|
||||
|
||||
// Override setHyperparameters from BaseClassifier
|
||||
void setHyperparameters(const nlohmann::json& hyperparameters) override;
|
||||
|
||||
protected:
|
||||
void buildModel(const torch::Tensor& weights) override;
|
||||
void trainModel(const torch::Tensor& weights, const Smoothing_t smoothing) override;
|
||||
|
||||
private:
|
||||
int n_estimators;
|
||||
std::vector<double> alphas; // Weight of each base estimator
|
||||
std::vector<double> training_errors; // Training error at each iteration
|
||||
torch::Tensor sample_weights; // Current sample weights
|
||||
|
||||
// Train a single base estimator
|
||||
std::unique_ptr<Classifier> trainBaseEstimator(const torch::Tensor& weights);
|
||||
|
||||
// Calculate weighted error
|
||||
double calculateWeightedError(Classifier* estimator, const torch::Tensor& weights);
|
||||
|
||||
// Update sample weights based on predictions
|
||||
void updateSampleWeights(Classifier* estimator, double alpha);
|
||||
|
||||
// Normalize weights to sum to 1
|
||||
void normalizeWeights();
|
||||
};
|
||||
}
|
||||
|
||||
#endif // ADABOOST_H
|
@@ -26,6 +26,7 @@
|
||||
#include <pyclassifiers/AdaBoost.h>
|
||||
#include <pyclassifiers/RandomForest.h>
|
||||
#include "../experimental_clfs/XA1DE.h"
|
||||
#include "../experimental_clfs/AdaBoost.h"
|
||||
|
||||
namespace platform {
|
||||
class Models {
|
||||
|
@@ -35,8 +35,10 @@ namespace platform {
|
||||
[](void) -> bayesnet::BaseClassifier* { return new pywrap::RandomForest();});
|
||||
static Registrar registrarXGB("XGBoost",
|
||||
[](void) -> bayesnet::BaseClassifier* { return new pywrap::XGBoost();});
|
||||
static Registrar registrarAda("AdaBoost",
|
||||
static Registrar registrarAda("AdaBoostPy",
|
||||
[](void) -> bayesnet::BaseClassifier* { return new pywrap::AdaBoost();});
|
||||
// static Registrar registrarAda2("AdaBoost",
|
||||
// [](void) -> bayesnet::BaseClassifier* { return new platform::AdaBoost();});
|
||||
static Registrar registrarXSPODE("XSPODE",
|
||||
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::XSpode(0);});
|
||||
static Registrar registrarXSP2DE("XSP2DE",
|
||||
|
Reference in New Issue
Block a user