Add DecisionTree with tests

This commit is contained in:
2025-06-17 13:48:11 +02:00
parent 8c413a1eb0
commit 023d5613b4
16 changed files with 1272 additions and 81 deletions

View File

@@ -20,6 +20,8 @@ add_executable(
results/Result.cpp
experimental_clfs/XA1DE.cpp
experimental_clfs/ExpClf.cpp
experimental_clfs/DecisionTree.cpp
)
target_link_libraries(b_best Boost::boost "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy "${XLSXWRITER_LIB}")
@@ -33,7 +35,7 @@ add_executable(b_grid commands/b_grid.cpp ${grid_sources}
results/Result.cpp
experimental_clfs/XA1DE.cpp
experimental_clfs/ExpClf.cpp
experimental_clfs/AdaBoost.cpp
experimental_clfs/DecisionTree.cpp
)
target_link_libraries(b_grid ${MPI_CXX_LIBRARIES} "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy)
@@ -45,6 +47,8 @@ add_executable(b_list commands/b_list.cpp
results/Result.cpp results/ResultsDatasetExcel.cpp results/ResultsDataset.cpp results/ResultsDatasetConsole.cpp
experimental_clfs/XA1DE.cpp
experimental_clfs/ExpClf.cpp
experimental_clfs/DecisionTree.cpp
)
target_link_libraries(b_list "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy "${XLSXWRITER_LIB}")
@@ -58,6 +62,8 @@ add_executable(b_main commands/b_main.cpp ${main_sources}
experimental_clfs/XA1DE.cpp
experimental_clfs/ExpClf.cpp
experimental_clfs/ExpClf.cpp
experimental_clfs/DecisionTree.cpp
)
target_link_libraries(b_main PRIVATE nlohmann_json::nlohmann_json "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy)

View File

@@ -5,18 +5,19 @@
// ***************************************************************
#include "AdaBoost.h"
#include "DecisionTree.h"
#include <cmath>
#include <algorithm>
#include <numeric>
#include <sstream>
#include <iomanip>
namespace platform {
namespace bayesnet {
AdaBoost::AdaBoost(int n_estimators)
: Ensemble(true), n_estimators(n_estimators)
AdaBoost::AdaBoost(int n_estimators, int max_depth)
: Ensemble(true), n_estimators(n_estimators), base_max_depth(max_depth)
{
validHyperparameters = { "n_estimators" };
validHyperparameters = { "n_estimators", "base_max_depth" };
}
void AdaBoost::buildModel(const torch::Tensor& weights)
@@ -89,20 +90,14 @@ namespace platform {
std::unique_ptr<Classifier> AdaBoost::trainBaseEstimator(const torch::Tensor& weights)
{
// Create a new classifier instance
// You need to implement this based on your specific base classifier
// For example, if using Decision Trees:
// auto classifier = std::make_unique<DecisionTree>();
// Create a decision tree with specified max depth
// For AdaBoost, we typically use shallow trees (stumps with max_depth=1)
auto tree = std::make_unique<DecisionTree>(base_max_depth);
// Or if using a factory method:
// auto classifier = ClassifierFactory::create("DecisionTree");
// Fit the tree with the current sample weights
tree->fit(dataset, features, className, states, weights, Smoothing_t::NONE);
// Placeholder - replace with actual classifier creation
throw std::runtime_error("AdaBoost::trainBaseEstimator - You need to implement base classifier creation");
// Once you have the classifier creation implemented, uncomment:
// classifier->fit(dataset, features, className, states, weights, Smoothing_t::NONE);
// return classifier;
return tree;
}
double AdaBoost::calculateWeightedError(Classifier* estimator, const torch::Tensor& weights)
@@ -192,8 +187,9 @@ namespace platform {
return graph_lines;
}
void AdaBoost::setHyperparameters(const nlohmann::json& hyperparameters)
void AdaBoost::setHyperparameters(const nlohmann::json& hyperparameters_)
{
auto hyperparameters = hyperparameters_;
// Set hyperparameters from JSON
auto it = hyperparameters.find("n_estimators");
if (it != hyperparameters.end()) {
@@ -201,14 +197,18 @@ namespace platform {
if (n_estimators <= 0) {
throw std::invalid_argument("n_estimators must be positive");
}
hyperparameters.erase("n_estimators"); // Remove 'n_estimators' if present
}
// Check for invalid hyperparameters
for (auto& [key, value] : hyperparameters.items()) {
if (std::find(validHyperparameters.begin(), validHyperparameters.end(), key) == validHyperparameters.end()) {
throw std::invalid_argument("Invalid hyperparameter: " + key);
it = hyperparameters.find("base_max_depth");
if (it != hyperparameters.end()) {
base_max_depth = it->get<int>();
if (base_max_depth <= 0) {
throw std::invalid_argument("base_max_depth must be positive");
}
hyperparameters.erase("base_max_depth"); // Remove 'base_max_depth' if present
}
Ensemble::setHyperparameters(hyperparameters);
}
} // namespace bayesnet

View File

@@ -9,13 +9,12 @@
#include <vector>
#include <memory>
#include <torch/torch.h>
#include <bayesnet/ensembles/Ensemble.h>
#include "bayesnet/ensembles/Ensemble.h"
namespace platform {
class AdaBoost : public bayesnet::Ensemble {
namespace bayesnet {
class AdaBoost : public Ensemble {
public:
explicit AdaBoost(int n_estimators = 100);
explicit AdaBoost(int n_estimators = 50, int max_depth = 1);
virtual ~AdaBoost() = default;
// Override base class methods
@@ -24,10 +23,15 @@ namespace platform {
// AdaBoost specific methods
void setNEstimators(int n_estimators) { this->n_estimators = n_estimators; }
int getNEstimators() const { return n_estimators; }
void setBaseMaxDepth(int depth) { this->base_max_depth = depth; }
int getBaseMaxDepth() const { return base_max_depth; }
// Get the weight of each base estimator
std::vector<double> getEstimatorWeights() const { return alphas; }
// Get training errors for each iteration
std::vector<double> getTrainingErrors() const { return training_errors; }
// Override setHyperparameters from BaseClassifier
void setHyperparameters(const nlohmann::json& hyperparameters) override;
@@ -37,6 +41,7 @@ namespace platform {
private:
int n_estimators;
int base_max_depth; // Max depth for base decision trees
std::vector<double> alphas; // Weight of each base estimator
std::vector<double> training_errors; // Training error at each iteration
torch::Tensor sample_weights; // Current sample weights

View File

@@ -0,0 +1,519 @@
// ***************************************************************
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
// SPDX-FileType: SOURCE
// SPDX-License-Identifier: MIT
// ***************************************************************
#include "DecisionTree.h"
#include <algorithm>
#include <numeric>
#include <sstream>
#include <iomanip>
#include <limits>
#include "TensorUtils.hpp"
namespace bayesnet {
DecisionTree::DecisionTree(int max_depth, int min_samples_split, int min_samples_leaf)
: Classifier(Network()), max_depth(max_depth),
min_samples_split(min_samples_split), min_samples_leaf(min_samples_leaf)
{
validHyperparameters = { "max_depth", "min_samples_split", "min_samples_leaf" };
}
void DecisionTree::setHyperparameters(const nlohmann::json& hyperparameters_)
{
auto hyperparameters = hyperparameters_;
// Set hyperparameters from JSON
auto it = hyperparameters.find("max_depth");
if (it != hyperparameters.end()) {
max_depth = it->get<int>();
hyperparameters.erase("max_depth"); // Remove 'order' if present
}
it = hyperparameters.find("min_samples_split");
if (it != hyperparameters.end()) {
min_samples_split = it->get<int>();
hyperparameters.erase("min_samples_split"); // Remove 'min_samples_split' if present
}
it = hyperparameters.find("min_samples_leaf");
if (it != hyperparameters.end()) {
min_samples_leaf = it->get<int>();
hyperparameters.erase("min_samples_leaf"); // Remove 'min_samples_leaf' if present
}
Classifier::setHyperparameters(hyperparameters);
checkValues();
}
void DecisionTree::checkValues()
{
if (max_depth <= 0) {
throw std::invalid_argument("max_depth must be positive");
}
if (min_samples_leaf <= 0) {
throw std::invalid_argument("min_samples_leaf must be positive");
}
if (min_samples_split <= 0) {
throw std::invalid_argument("min_samples_split must be positive");
}
}
void DecisionTree::buildModel(const torch::Tensor& weights)
{
// Extract features (X) and labels (y) from dataset
auto X = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), torch::indexing::Slice() }).t();
auto y = dataset.index({ -1, torch::indexing::Slice() });
if (X.size(0) != y.size(0)) {
throw std::runtime_error("X and y must have the same number of samples");
}
n_classes = states[className].size();
// Use provided weights or uniform weights
torch::Tensor sample_weights;
if (weights.defined() && weights.numel() > 0) {
if (weights.size(0) != X.size(0)) {
throw std::runtime_error("weights must have the same length as number of samples");
}
sample_weights = weights;
} else {
sample_weights = torch::ones({ X.size(0) }) / X.size(0);
}
// Normalize weights
sample_weights = sample_weights / sample_weights.sum();
// Build the tree
root = buildTree(X, y, sample_weights, 0);
// Mark as fitted
fitted = true;
}
bool DecisionTree::validateTensors(const torch::Tensor& X, const torch::Tensor& y,
const torch::Tensor& sample_weights) const
{
if (X.size(0) != y.size(0) || X.size(0) != sample_weights.size(0)) {
return false;
}
if (X.size(0) == 0) {
return false;
}
return true;
}
std::unique_ptr<TreeNode> DecisionTree::buildTree(
const torch::Tensor& X,
const torch::Tensor& y,
const torch::Tensor& sample_weights,
int current_depth)
{
auto node = std::make_unique<TreeNode>();
int n_samples = y.size(0);
// Check stopping criteria
auto unique = at::_unique(y);
bool should_stop = (current_depth >= max_depth) ||
(n_samples < min_samples_split) ||
(std::get<0>(unique).size(0) == 1); // All samples same class
if (should_stop || n_samples <= min_samples_leaf) {
// Create leaf node
node->is_leaf = true;
// Calculate class probabilities
node->class_probabilities = torch::zeros({ n_classes });
for (int i = 0; i < n_samples; i++) {
int class_idx = y[i].item<int>();
node->class_probabilities[class_idx] += sample_weights[i].item<float>();
}
// Normalize probabilities
node->class_probabilities /= node->class_probabilities.sum();
// Set predicted class as the one with highest probability
node->predicted_class = torch::argmax(node->class_probabilities).item<int>();
return node;
}
// Find best split
SplitInfo best_split = findBestSplit(X, y, sample_weights);
// If no valid split found, create leaf
if (best_split.feature_index == -1 || best_split.impurity_decrease <= 0) {
node->is_leaf = true;
// Calculate class probabilities
node->class_probabilities = torch::zeros({ n_classes });
for (int i = 0; i < n_samples; i++) {
int class_idx = y[i].item<int>();
node->class_probabilities[class_idx] += sample_weights[i].item<float>();
}
node->class_probabilities /= node->class_probabilities.sum();
node->predicted_class = torch::argmax(node->class_probabilities).item<int>();
return node;
}
// Create internal node
node->is_leaf = false;
node->split_feature = best_split.feature_index;
node->split_value = best_split.split_value;
// Split data
auto left_X = X.index({ best_split.left_mask });
auto left_y = y.index({ best_split.left_mask });
auto left_weights = sample_weights.index({ best_split.left_mask });
auto right_X = X.index({ best_split.right_mask });
auto right_y = y.index({ best_split.right_mask });
auto right_weights = sample_weights.index({ best_split.right_mask });
// Recursively build subtrees
if (left_X.size(0) >= min_samples_leaf) {
node->left = buildTree(left_X, left_y, left_weights, current_depth + 1);
} else {
// Force leaf if not enough samples
node->left = std::make_unique<TreeNode>();
node->left->is_leaf = true;
auto mode = std::get<0>(torch::mode(left_y));
node->left->predicted_class = mode.item<int>();
node->left->class_probabilities = torch::zeros({ n_classes });
node->left->class_probabilities[node->left->predicted_class] = 1.0;
}
if (right_X.size(0) >= min_samples_leaf) {
node->right = buildTree(right_X, right_y, right_weights, current_depth + 1);
} else {
// Force leaf if not enough samples
node->right = std::make_unique<TreeNode>();
node->right->is_leaf = true;
auto mode = std::get<0>(torch::mode(right_y));
node->right->predicted_class = mode.item<int>();
node->right->class_probabilities = torch::zeros({ n_classes });
node->right->class_probabilities[node->right->predicted_class] = 1.0;
}
return node;
}
DecisionTree::SplitInfo DecisionTree::findBestSplit(
const torch::Tensor& X,
const torch::Tensor& y,
const torch::Tensor& sample_weights)
{
SplitInfo best_split;
best_split.feature_index = -1;
best_split.split_value = -1;
best_split.impurity_decrease = -std::numeric_limits<double>::infinity();
int n_features = X.size(1);
int n_samples = X.size(0);
// Calculate impurity of current node
double current_impurity = calculateGiniImpurity(y, sample_weights);
double total_weight = sample_weights.sum().item<double>();
// Try each feature
for (int feat_idx = 0; feat_idx < n_features; feat_idx++) {
auto feature_values = X.index({ torch::indexing::Slice(), feat_idx });
auto unique_values = std::get<0>(torch::unique_consecutive(std::get<0>(torch::sort(feature_values))));
// Try each unique value as split point
for (int i = 0; i < unique_values.size(0); i++) {
int split_val = unique_values[i].item<int>();
// Create masks for left and right splits
auto left_mask = feature_values == split_val;
auto right_mask = ~left_mask;
int left_count = left_mask.sum().item<int>();
int right_count = right_mask.sum().item<int>();
// Skip if split doesn't satisfy minimum samples requirement
if (left_count < min_samples_leaf || right_count < min_samples_leaf) {
continue;
}
// Calculate weighted impurities
auto left_y = y.index({ left_mask });
auto left_weights = sample_weights.index({ left_mask });
double left_weight = left_weights.sum().item<double>();
double left_impurity = calculateGiniImpurity(left_y, left_weights);
auto right_y = y.index({ right_mask });
auto right_weights = sample_weights.index({ right_mask });
double right_weight = right_weights.sum().item<double>();
double right_impurity = calculateGiniImpurity(right_y, right_weights);
// Calculate impurity decrease
double impurity_decrease = current_impurity -
(left_weight / total_weight * left_impurity +
right_weight / total_weight * right_impurity);
// Update best split if this is better
if (impurity_decrease > best_split.impurity_decrease) {
best_split.feature_index = feat_idx;
best_split.split_value = split_val;
best_split.impurity_decrease = impurity_decrease;
best_split.left_mask = left_mask;
best_split.right_mask = right_mask;
}
}
}
return best_split;
}
double DecisionTree::calculateGiniImpurity(
const torch::Tensor& y,
const torch::Tensor& sample_weights)
{
if (y.size(0) == 0 || sample_weights.size(0) == 0) {
return 0.0;
}
if (y.size(0) != sample_weights.size(0)) {
throw std::runtime_error("y and sample_weights must have same size");
}
torch::Tensor class_weights = torch::zeros({ n_classes });
// Calculate weighted class counts
for (int i = 0; i < y.size(0); i++) {
int class_idx = y[i].item<int>();
if (class_idx < 0 || class_idx >= n_classes) {
throw std::runtime_error("Invalid class index: " + std::to_string(class_idx));
}
class_weights[class_idx] += sample_weights[i].item<float>();
}
// Normalize
double total_weight = class_weights.sum().item<double>();
if (total_weight == 0) return 0.0;
class_weights /= total_weight;
// Calculate Gini impurity: 1 - sum(p_i^2)
double gini = 1.0;
for (int i = 0; i < n_classes; i++) {
double p = class_weights[i].item<double>();
gini -= p * p;
}
return gini;
}
torch::Tensor DecisionTree::predict(torch::Tensor& X)
{
if (!fitted) {
throw std::runtime_error(CLASSIFIER_NOT_FITTED);
}
int n_samples = X.size(1);
torch::Tensor predictions = torch::zeros({ n_samples }, torch::kInt32);
for (int i = 0; i < n_samples; i++) {
auto sample = X.index({ torch::indexing::Slice(), i }).ravel();
predictions[i] = predictSample(sample);
}
return predictions;
}
void dumpTensor(const torch::Tensor& tensor, const std::string& name)
{
std::cout << name << ": " << std::endl;
for (int i = 0; i < tensor.size(0); i++) {
std::cout << "[";
for (int j = 0; j < tensor.size(1); j++) {
std::cout << tensor[i][j].item<int>() << " ";
}
std::cout << "]" << std::endl;
}
std::cout << std::endl;
}
void dumpVector(const std::vector<std::vector<int>>& vec, const std::string& name)
{
std::cout << name << ": " << std::endl;;
for (const auto& row : vec) {
std::cout << "[";
for (const auto& val : row) {
std::cout << val << " ";
}
std::cout << "] " << std::endl;
}
std::cout << std::endl;
}
std::vector<int> DecisionTree::predict(std::vector<std::vector<int>>& X)
{
// Convert to tensor
long n = X.size();
long m = X.at(0).size();
torch::Tensor X_tensor = platform::TensorUtils::to_matrix(X);
auto predictions = predict(X_tensor);
std::vector<int> result = platform::TensorUtils::to_vector<int>(predictions);
return result;
}
torch::Tensor DecisionTree::predict_proba(torch::Tensor& X)
{
if (!fitted) {
throw std::runtime_error(CLASSIFIER_NOT_FITTED);
}
int n_samples = X.size(1);
torch::Tensor probabilities = torch::zeros({ n_samples, n_classes });
for (int i = 0; i < n_samples; i++) {
auto sample = X.index({ torch::indexing::Slice(), i }).ravel();
probabilities[i] = predictProbaSample(sample);
}
return probabilities;
}
std::vector<std::vector<double>> DecisionTree::predict_proba(std::vector<std::vector<int>>& X)
{
auto n_samples = X.at(0).size();
// Convert to tensor
torch::Tensor X_tensor = platform::TensorUtils::to_matrix(X);
auto proba_tensor = predict_proba(X_tensor);
std::vector<std::vector<double>> result(n_samples, std::vector<double>(n_classes, 0.0));
for (int i = 0; i < n_samples; i++) {
for (int j = 0; j < n_classes; j++) {
result[i][j] = proba_tensor[i][j].item<double>();
}
}
return result;
}
int DecisionTree::predictSample(const torch::Tensor& x) const
{
if (!fitted) {
throw std::runtime_error(CLASSIFIER_NOT_FITTED);
}
if (x.size(0) != n) { // n debería ser el número de características
throw std::runtime_error("Input sample has wrong number of features");
}
const TreeNode* leaf = traverseTree(x, root.get());
return leaf->predicted_class;
}
torch::Tensor DecisionTree::predictProbaSample(const torch::Tensor& x) const
{
const TreeNode* leaf = traverseTree(x, root.get());
return leaf->class_probabilities.clone();
}
const TreeNode* DecisionTree::traverseTree(const torch::Tensor& x, const TreeNode* node) const
{
if (!node) {
throw std::runtime_error("Null node encountered during tree traversal");
}
if (node->is_leaf) {
return node;
}
if (node->split_feature < 0 || node->split_feature >= x.size(0)) {
throw std::runtime_error("Invalid split_feature index: " + std::to_string(node->split_feature));
}
int feature_value = x[node->split_feature].item<int>();
if (feature_value == node->split_value) {
if (!node->left) {
throw std::runtime_error("Missing left child in tree");
}
return traverseTree(x, node->left.get());
} else {
if (!node->right) {
throw std::runtime_error("Missing right child in tree");
}
return traverseTree(x, node->right.get());
}
}
std::vector<std::string> DecisionTree::graph(const std::string& title) const
{
std::vector<std::string> lines;
lines.push_back("digraph DecisionTree {");
lines.push_back(" rankdir=TB;");
lines.push_back(" node [shape=box, style=\"filled, rounded\", fontname=\"helvetica\"];");
lines.push_back(" edge [fontname=\"helvetica\"];");
if (!title.empty()) {
lines.push_back(" label=\"" + title + "\";");
lines.push_back(" labelloc=t;");
}
if (root) {
int node_id = 0;
treeToGraph(root.get(), lines, node_id);
}
lines.push_back("}");
return lines;
}
void DecisionTree::treeToGraph(
const TreeNode* node,
std::vector<std::string>& lines,
int& node_id,
int parent_id,
const std::string& edge_label) const
{
int current_id = node_id++;
std::stringstream ss;
if (node->is_leaf) {
// Leaf node
ss << " node" << current_id << " [label=\"Class: " << node->predicted_class;
ss << "\\nProb: " << std::fixed << std::setprecision(3)
<< node->class_probabilities[node->predicted_class].item<float>();
ss << "\", fillcolor=\"lightblue\"];";
lines.push_back(ss.str());
} else {
// Internal node
ss << " node" << current_id << " [label=\"" << features[node->split_feature];
ss << " = " << node->split_value << "?\", fillcolor=\"lightgreen\"];";
lines.push_back(ss.str());
}
// Add edge from parent
if (parent_id >= 0) {
ss.str("");
ss << " node" << parent_id << " -> node" << current_id;
if (!edge_label.empty()) {
ss << " [label=\"" << edge_label << "\"];";
} else {
ss << ";";
}
lines.push_back(ss.str());
}
// Recurse on children
if (!node->is_leaf) {
if (node->left) {
treeToGraph(node->left.get(), lines, node_id, current_id, "Yes");
}
if (node->right) {
treeToGraph(node->right.get(), lines, node_id, current_id, "No");
}
}
}
} // namespace bayesnet

View File

@@ -0,0 +1,129 @@
// ***************************************************************
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
// SPDX-FileType: SOURCE
// SPDX-License-Identifier: MIT
// ***************************************************************
#ifndef DECISION_TREE_H
#define DECISION_TREE_H
#include <memory>
#include <vector>
#include <map>
#include <torch/torch.h>
#include "bayesnet/classifiers/Classifier.h"
namespace bayesnet {
// Forward declaration
struct TreeNode;
class DecisionTree : public Classifier {
public:
explicit DecisionTree(int max_depth = 3, int min_samples_split = 2, int min_samples_leaf = 1);
virtual ~DecisionTree() = default;
// Override graph method to show tree structure
std::vector<std::string> graph(const std::string& title = "") const override;
// Setters for hyperparameters
void setMaxDepth(int depth) { max_depth = depth; checkValues(); }
void setMinSamplesSplit(int samples) { min_samples_split = samples; checkValues(); }
void setMinSamplesLeaf(int samples) { min_samples_leaf = samples; checkValues(); }
// Override setHyperparameters
void setHyperparameters(const nlohmann::json& hyperparameters) override;
torch::Tensor predict(torch::Tensor& X) override;
std::vector<int> predict(std::vector<std::vector<int>>& X) override;
torch::Tensor predict_proba(torch::Tensor& X) override;
std::vector<std::vector<double>> predict_proba(std::vector<std::vector<int>>& X);
protected:
void buildModel(const torch::Tensor& weights) override;
void trainModel(const torch::Tensor& weights, const Smoothing_t smoothing) override
{
// Decision trees do not require training in the traditional sense
// as they are built from the data directly.
// This method can be used to set weights or other parameters if needed.
}
private:
void checkValues();
bool validateTensors(const torch::Tensor& X, const torch::Tensor& y, const torch::Tensor& sample_weights) const;
// Tree hyperparameters
int max_depth;
int min_samples_split;
int min_samples_leaf;
int n_classes; // Number of classes in the target variable
// Root of the decision tree
std::unique_ptr<TreeNode> root;
// Build tree recursively
std::unique_ptr<TreeNode> buildTree(
const torch::Tensor& X,
const torch::Tensor& y,
const torch::Tensor& sample_weights,
int current_depth
);
// Find best split for a node
struct SplitInfo {
int feature_index;
int split_value;
double impurity_decrease;
torch::Tensor left_mask;
torch::Tensor right_mask;
};
SplitInfo findBestSplit(
const torch::Tensor& X,
const torch::Tensor& y,
const torch::Tensor& sample_weights
);
// Calculate weighted Gini impurity for multi-class
double calculateGiniImpurity(
const torch::Tensor& y,
const torch::Tensor& sample_weights
);
// Make predictions for a single sample
int predictSample(const torch::Tensor& x) const;
// Make probabilistic predictions for a single sample
torch::Tensor predictProbaSample(const torch::Tensor& x) const;
// Traverse tree to find leaf node
const TreeNode* traverseTree(const torch::Tensor& x, const TreeNode* node) const;
// Convert tree to graph representation
void treeToGraph(
const TreeNode* node,
std::vector<std::string>& lines,
int& node_id,
int parent_id = -1,
const std::string& edge_label = ""
) const;
};
// Tree node structure
struct TreeNode {
bool is_leaf;
// For internal nodes
int split_feature;
int split_value;
std::unique_ptr<TreeNode> left;
std::unique_ptr<TreeNode> right;
// For leaf nodes
int predicted_class;
torch::Tensor class_probabilities; // Probability for each class
TreeNode() : is_leaf(false), split_feature(-1), split_value(-1), predicted_class(-1) {}
};
} // namespace bayesnet
#endif // DECISION_TREE_H

View File

@@ -0,0 +1,142 @@
# AdaBoost and DecisionTree Classifier Implementation
This implementation provides both a Decision Tree classifier and a multi-class AdaBoost classifier based on the SAMME (Stagewise Additive Modeling using a Multi-class Exponential loss) algorithm described in the paper "Multi-class AdaBoost" by Zhu et al. Implemented in C++ using <https://claude.ai>
## Components
### 1. DecisionTree Classifier
A classic decision tree implementation that:
- Supports multi-class classification
- Handles weighted samples (essential for boosting)
- Uses Gini impurity as the splitting criterion
- Works with discrete/categorical features
- Provides both class predictions and probability estimates
#### Key Features
- **Max Depth Control**: Limit tree depth to create weak learners
- **Minimum Samples**: Control minimum samples for splitting and leaf nodes
- **Weighted Training**: Properly handles sample weights for boosting
- **Visualization**: Generates DOT format graphs of the tree structure
#### Hyperparameters
- `max_depth`: Maximum depth of the tree (default: 3)
- `min_samples_split`: Minimum samples required to split a node (default: 2)
- `min_samples_leaf`: Minimum samples required in a leaf node (default: 1)
### 2. AdaBoost Classifier
A multi-class AdaBoost implementation using DecisionTree as base estimators:
- **SAMME Algorithm**: Implements the multi-class extension of AdaBoost
- **Automatic Stumps**: Uses decision stumps (max_depth=1) by default
- **Early Stopping**: Stops if base classifier performs worse than random
- **Ensemble Visualization**: Shows the weighted combination of base estimators
#### Key Features
- **Multi-class Support**: Natural extension to K classes
- **Base Estimator Control**: Configure depth of base decision trees
- **Training Monitoring**: Track training errors and estimator weights
- **Probability Estimates**: Provides class probability predictions
#### Hyperparameters
- `n_estimators`: Number of base estimators to train (default: 50)
- `base_max_depth`: Maximum depth for base decision trees (default: 1)
## Algorithm Details
The SAMME algorithm differs from binary AdaBoost in the calculation of the estimator weight (alpha):
```
α = log((1 - err) / err) + log(K - 1)
```
where `K` is the number of classes. This formula ensures that:
- When K = 2, it reduces to standard AdaBoost
- For K > 2, base classifiers only need to be better than random guessing (1/K) rather than 50%
## Usage Example
```cpp
// Create AdaBoost with decision stumps
AdaBoost ada(100, 1); // 100 estimators, max_depth=1
// Train
ada.fit(X_train, y_train, features, className, states, Smoothing_t::NONE);
// Predict
auto predictions = ada.predict(X_test);
auto probabilities = ada.predict_proba(X_test);
// Evaluate
float accuracy = ada.score(X_test, y_test);
// Get ensemble information
auto weights = ada.getEstimatorWeights();
auto errors = ada.getTrainingErrors();
```
## Implementation Structure
```
AdaBoost (inherits from Ensemble)
└── Uses multiple DecisionTree instances as base estimators
└── DecisionTree (inherits from Classifier)
└── Implements weighted Gini impurity splitting
```
## Visualization
Both classifiers support graph visualization:
- **DecisionTree**: Shows the tree structure with split conditions
- **AdaBoost**: Shows the ensemble of weighted base estimators
Generate visualizations using:
```cpp
auto graph = classifier.graph("Title");
```
## Data Format
Both classifiers expect discrete/categorical data:
- **Features**: Integer values representing categories (stored in `torch::Tensor` or `std::vector<std::vector<int>>`)
- **Labels**: Integer values representing class indices (0, 1, ..., K-1)
- **States**: Map defining possible values for each feature and the class variable
- **Sample Weights**: Optional weights for each training sample (important for boosting)
Example data setup:
```cpp
// Features matrix (n_features x n_samples)
torch::Tensor X = torch::tensor({{0, 1, 2}, {1, 0, 1}}); // 2 features, 3 samples
// Labels vector
torch::Tensor y = torch::tensor({0, 1, 0}); // 3 samples
// States definition
std::map<std::string, std::vector<int>> states;
states["feature1"] = {0, 1, 2}; // Feature 1 can take values 0, 1, or 2
states["feature2"] = {0, 1}; // Feature 2 can take values 0 or 1
states["class"] = {0, 1}; // Binary classification
```
## Notes
- The implementation handles discrete/categorical features as indicated by the int-based data structures
- Sample weights are properly propagated through the tree building process
- The DecisionTree implementation uses equality testing for splits (suitable for categorical data)
- Both classifiers support the standard fit/predict interface from the base framework
## References
- Zhu, J., Zou, H., Rosset, S., & Hastie, T. (2009). Multi-class AdaBoost. Statistics and its interface, 2(3), 349-360.
- Breiman, L., Friedman, J., Olshen, R., & Stone, C. (1984). Classification and Regression Trees. Wadsworth, Belmont, CA.

View File

@@ -45,6 +45,19 @@ namespace platform {
return data;
}
static torch::Tensor to_matrix(const std::vector<std::vector<int>>& data)
{
if (data.empty()) return torch::empty({ 0, 0 }, torch::kInt64);
size_t rows = data.size();
size_t cols = data[0].size();
torch::Tensor tensor = torch::empty({ static_cast<long>(rows), static_cast<long>(cols) }, torch::kInt64);
for (size_t i = 0; i < rows; ++i) {
for (size_t j = 0; j < cols; ++j) {
tensor.index_put_({ static_cast<long>(i), static_cast<long>(j) }, data[i][j]);
}
}
return tensor;
}
};
}

View File

@@ -23,10 +23,11 @@
#include <pyclassifiers/ODTE.h>
#include <pyclassifiers/SVC.h>
#include <pyclassifiers/XGBoost.h>
#include <pyclassifiers/AdaBoost.h>
#include <pyclassifiers/AdaBoostPy.h>
#include <pyclassifiers/RandomForest.h>
#include "../experimental_clfs/XA1DE.h"
#include "../experimental_clfs/AdaBoost.h"
// #include "../experimental_clfs/AdaBoost.h"
#include "../experimental_clfs/DecisionTree.h"
namespace platform {
class Models {

View File

@@ -35,10 +35,12 @@ namespace platform {
[](void) -> bayesnet::BaseClassifier* { return new pywrap::RandomForest();});
static Registrar registrarXGB("XGBoost",
[](void) -> bayesnet::BaseClassifier* { return new pywrap::XGBoost();});
static Registrar registrarAda("AdaBoostPy",
[](void) -> bayesnet::BaseClassifier* { return new pywrap::AdaBoost();});
// static Registrar registrarAda2("AdaBoost",
// [](void) -> bayesnet::BaseClassifier* { return new platform::AdaBoost();});
static Registrar registrarAdaPy("AdaBoostPy",
[](void) -> bayesnet::BaseClassifier* { return new pywrap::AdaBoostPy();});
// static Registrar registrarAda("AdaBoost",
// [](void) -> bayesnet::BaseClassifier* { return new bayesnet::AdaBoost();});
static Registrar registrarDT("DecisionTree",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::DecisionTree();});
static Registrar registrarXSPODE("XSPODE",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::XSpode(0);});
static Registrar registrarXSP2DE("XSP2DE",