495 lines
17 KiB
C++
495 lines
17 KiB
C++
// ***************************************************************
|
|
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
|
|
// SPDX-FileType: SOURCE
|
|
// SPDX-License-Identifier: MIT
|
|
// ***************************************************************
|
|
|
|
#include "DecisionTree.h"
|
|
#include <algorithm>
|
|
#include <numeric>
|
|
#include <sstream>
|
|
#include <iomanip>
|
|
#include <limits>
|
|
#include "TensorUtils.hpp"
|
|
|
|
namespace bayesnet {
|
|
|
|
DecisionTree::DecisionTree(int max_depth, int min_samples_split, int min_samples_leaf)
|
|
: Classifier(Network()), max_depth(max_depth),
|
|
min_samples_split(min_samples_split), min_samples_leaf(min_samples_leaf)
|
|
{
|
|
validHyperparameters = { "max_depth", "min_samples_split", "min_samples_leaf" };
|
|
}
|
|
|
|
void DecisionTree::setHyperparameters(const nlohmann::json& hyperparameters_)
|
|
{
|
|
auto hyperparameters = hyperparameters_;
|
|
// Set hyperparameters from JSON
|
|
auto it = hyperparameters.find("max_depth");
|
|
if (it != hyperparameters.end()) {
|
|
max_depth = it->get<int>();
|
|
hyperparameters.erase("max_depth"); // Remove 'order' if present
|
|
}
|
|
|
|
it = hyperparameters.find("min_samples_split");
|
|
if (it != hyperparameters.end()) {
|
|
min_samples_split = it->get<int>();
|
|
hyperparameters.erase("min_samples_split"); // Remove 'min_samples_split' if present
|
|
}
|
|
|
|
it = hyperparameters.find("min_samples_leaf");
|
|
if (it != hyperparameters.end()) {
|
|
min_samples_leaf = it->get<int>();
|
|
hyperparameters.erase("min_samples_leaf"); // Remove 'min_samples_leaf' if present
|
|
}
|
|
Classifier::setHyperparameters(hyperparameters);
|
|
checkValues();
|
|
}
|
|
void DecisionTree::checkValues()
|
|
{
|
|
if (max_depth <= 0) {
|
|
throw std::invalid_argument("max_depth must be positive");
|
|
}
|
|
if (min_samples_leaf <= 0) {
|
|
throw std::invalid_argument("min_samples_leaf must be positive");
|
|
}
|
|
if (min_samples_split <= 0) {
|
|
throw std::invalid_argument("min_samples_split must be positive");
|
|
}
|
|
}
|
|
void DecisionTree::buildModel(const torch::Tensor& weights)
|
|
{
|
|
// Extract features (X) and labels (y) from dataset
|
|
auto X = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), torch::indexing::Slice() }).t();
|
|
auto y = dataset.index({ -1, torch::indexing::Slice() });
|
|
|
|
if (X.size(0) != y.size(0)) {
|
|
throw std::runtime_error("X and y must have the same number of samples");
|
|
}
|
|
|
|
n_classes = states[className].size();
|
|
|
|
// Use provided weights or uniform weights
|
|
torch::Tensor sample_weights;
|
|
if (weights.defined() && weights.numel() > 0) {
|
|
if (weights.size(0) != X.size(0)) {
|
|
throw std::runtime_error("weights must have the same length as number of samples");
|
|
}
|
|
sample_weights = weights;
|
|
} else {
|
|
sample_weights = torch::ones({ X.size(0) }) / X.size(0);
|
|
}
|
|
|
|
// Normalize weights
|
|
sample_weights = sample_weights / sample_weights.sum();
|
|
|
|
// Build the tree
|
|
root = buildTree(X, y, sample_weights, 0);
|
|
|
|
// Mark as fitted
|
|
fitted = true;
|
|
}
|
|
bool DecisionTree::validateTensors(const torch::Tensor& X, const torch::Tensor& y,
|
|
const torch::Tensor& sample_weights) const
|
|
{
|
|
if (X.size(0) != y.size(0) || X.size(0) != sample_weights.size(0)) {
|
|
return false;
|
|
}
|
|
if (X.size(0) == 0) {
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
std::unique_ptr<TreeNode> DecisionTree::buildTree(
|
|
const torch::Tensor& X,
|
|
const torch::Tensor& y,
|
|
const torch::Tensor& sample_weights,
|
|
int current_depth)
|
|
{
|
|
auto node = std::make_unique<TreeNode>();
|
|
int n_samples = y.size(0);
|
|
|
|
// Check stopping criteria
|
|
auto unique = at::_unique(y);
|
|
bool should_stop = (current_depth >= max_depth) ||
|
|
(n_samples < min_samples_split) ||
|
|
(std::get<0>(unique).size(0) == 1); // All samples same class
|
|
|
|
if (should_stop || n_samples <= min_samples_leaf) {
|
|
// Create leaf node
|
|
node->is_leaf = true;
|
|
|
|
// Calculate class probabilities
|
|
node->class_probabilities = torch::zeros({ n_classes });
|
|
|
|
for (int i = 0; i < n_samples; i++) {
|
|
int class_idx = y[i].item<int>();
|
|
node->class_probabilities[class_idx] += sample_weights[i].item<float>();
|
|
}
|
|
|
|
// Normalize probabilities
|
|
node->class_probabilities /= node->class_probabilities.sum();
|
|
|
|
// Set predicted class as the one with highest probability
|
|
node->predicted_class = torch::argmax(node->class_probabilities).item<int>();
|
|
|
|
return node;
|
|
}
|
|
|
|
// Find best split
|
|
SplitInfo best_split = findBestSplit(X, y, sample_weights);
|
|
|
|
// If no valid split found, create leaf
|
|
if (best_split.feature_index == -1 || best_split.impurity_decrease <= 0) {
|
|
node->is_leaf = true;
|
|
|
|
// Calculate class probabilities
|
|
node->class_probabilities = torch::zeros({ n_classes });
|
|
|
|
for (int i = 0; i < n_samples; i++) {
|
|
int class_idx = y[i].item<int>();
|
|
node->class_probabilities[class_idx] += sample_weights[i].item<float>();
|
|
}
|
|
|
|
node->class_probabilities /= node->class_probabilities.sum();
|
|
node->predicted_class = torch::argmax(node->class_probabilities).item<int>();
|
|
|
|
return node;
|
|
}
|
|
|
|
// Create internal node
|
|
node->is_leaf = false;
|
|
node->split_feature = best_split.feature_index;
|
|
node->split_value = best_split.split_value;
|
|
|
|
// Split data
|
|
auto left_X = X.index({ best_split.left_mask });
|
|
auto left_y = y.index({ best_split.left_mask });
|
|
auto left_weights = sample_weights.index({ best_split.left_mask });
|
|
|
|
auto right_X = X.index({ best_split.right_mask });
|
|
auto right_y = y.index({ best_split.right_mask });
|
|
auto right_weights = sample_weights.index({ best_split.right_mask });
|
|
|
|
// Recursively build subtrees
|
|
if (left_X.size(0) >= min_samples_leaf) {
|
|
node->left = buildTree(left_X, left_y, left_weights, current_depth + 1);
|
|
} else {
|
|
// Force leaf if not enough samples
|
|
node->left = std::make_unique<TreeNode>();
|
|
node->left->is_leaf = true;
|
|
auto mode = std::get<0>(torch::mode(left_y));
|
|
node->left->predicted_class = mode.item<int>();
|
|
node->left->class_probabilities = torch::zeros({ n_classes });
|
|
node->left->class_probabilities[node->left->predicted_class] = 1.0;
|
|
}
|
|
|
|
if (right_X.size(0) >= min_samples_leaf) {
|
|
node->right = buildTree(right_X, right_y, right_weights, current_depth + 1);
|
|
} else {
|
|
// Force leaf if not enough samples
|
|
node->right = std::make_unique<TreeNode>();
|
|
node->right->is_leaf = true;
|
|
auto mode = std::get<0>(torch::mode(right_y));
|
|
node->right->predicted_class = mode.item<int>();
|
|
node->right->class_probabilities = torch::zeros({ n_classes });
|
|
node->right->class_probabilities[node->right->predicted_class] = 1.0;
|
|
}
|
|
|
|
return node;
|
|
}
|
|
|
|
DecisionTree::SplitInfo DecisionTree::findBestSplit(
|
|
const torch::Tensor& X,
|
|
const torch::Tensor& y,
|
|
const torch::Tensor& sample_weights)
|
|
{
|
|
|
|
SplitInfo best_split;
|
|
best_split.feature_index = -1;
|
|
best_split.split_value = -1;
|
|
best_split.impurity_decrease = -std::numeric_limits<double>::infinity();
|
|
|
|
int n_features = X.size(1);
|
|
int n_samples = X.size(0);
|
|
|
|
// Calculate impurity of current node
|
|
double current_impurity = calculateGiniImpurity(y, sample_weights);
|
|
double total_weight = sample_weights.sum().item<double>();
|
|
|
|
// Try each feature
|
|
for (int feat_idx = 0; feat_idx < n_features; feat_idx++) {
|
|
auto feature_values = X.index({ torch::indexing::Slice(), feat_idx });
|
|
auto unique_values = std::get<0>(torch::unique_consecutive(std::get<0>(torch::sort(feature_values))));
|
|
|
|
// Try each unique value as split point
|
|
for (int i = 0; i < unique_values.size(0); i++) {
|
|
int split_val = unique_values[i].item<int>();
|
|
|
|
// Create masks for left and right splits
|
|
auto left_mask = feature_values == split_val;
|
|
auto right_mask = ~left_mask;
|
|
|
|
int left_count = left_mask.sum().item<int>();
|
|
int right_count = right_mask.sum().item<int>();
|
|
|
|
// Skip if split doesn't satisfy minimum samples requirement
|
|
if (left_count < min_samples_leaf || right_count < min_samples_leaf) {
|
|
continue;
|
|
}
|
|
|
|
// Calculate weighted impurities
|
|
auto left_y = y.index({ left_mask });
|
|
auto left_weights = sample_weights.index({ left_mask });
|
|
double left_weight = left_weights.sum().item<double>();
|
|
double left_impurity = calculateGiniImpurity(left_y, left_weights);
|
|
|
|
auto right_y = y.index({ right_mask });
|
|
auto right_weights = sample_weights.index({ right_mask });
|
|
double right_weight = right_weights.sum().item<double>();
|
|
double right_impurity = calculateGiniImpurity(right_y, right_weights);
|
|
|
|
// Calculate impurity decrease
|
|
double impurity_decrease = current_impurity -
|
|
(left_weight / total_weight * left_impurity +
|
|
right_weight / total_weight * right_impurity);
|
|
|
|
// Update best split if this is better
|
|
if (impurity_decrease > best_split.impurity_decrease) {
|
|
best_split.feature_index = feat_idx;
|
|
best_split.split_value = split_val;
|
|
best_split.impurity_decrease = impurity_decrease;
|
|
best_split.left_mask = left_mask;
|
|
best_split.right_mask = right_mask;
|
|
}
|
|
}
|
|
}
|
|
|
|
return best_split;
|
|
}
|
|
|
|
double DecisionTree::calculateGiniImpurity(
|
|
const torch::Tensor& y,
|
|
const torch::Tensor& sample_weights)
|
|
{
|
|
if (y.size(0) == 0 || sample_weights.size(0) == 0) {
|
|
return 0.0;
|
|
}
|
|
|
|
if (y.size(0) != sample_weights.size(0)) {
|
|
throw std::runtime_error("y and sample_weights must have same size");
|
|
}
|
|
|
|
torch::Tensor class_weights = torch::zeros({ n_classes });
|
|
|
|
// Calculate weighted class counts
|
|
for (int i = 0; i < y.size(0); i++) {
|
|
int class_idx = y[i].item<int>();
|
|
|
|
if (class_idx < 0 || class_idx >= n_classes) {
|
|
throw std::runtime_error("Invalid class index: " + std::to_string(class_idx));
|
|
}
|
|
|
|
class_weights[class_idx] += sample_weights[i].item<float>();
|
|
}
|
|
|
|
// Normalize
|
|
double total_weight = class_weights.sum().item<double>();
|
|
if (total_weight == 0) return 0.0;
|
|
|
|
class_weights /= total_weight;
|
|
|
|
// Calculate Gini impurity: 1 - sum(p_i^2)
|
|
double gini = 1.0;
|
|
for (int i = 0; i < n_classes; i++) {
|
|
double p = class_weights[i].item<double>();
|
|
gini -= p * p;
|
|
}
|
|
|
|
return gini;
|
|
}
|
|
|
|
|
|
torch::Tensor DecisionTree::predict(torch::Tensor& X)
|
|
{
|
|
if (!fitted) {
|
|
throw std::runtime_error(CLASSIFIER_NOT_FITTED);
|
|
}
|
|
|
|
int n_samples = X.size(1);
|
|
torch::Tensor predictions = torch::zeros({ n_samples }, torch::kInt32);
|
|
|
|
for (int i = 0; i < n_samples; i++) {
|
|
auto sample = X.index({ torch::indexing::Slice(), i }).ravel();
|
|
predictions[i] = predictSample(sample);
|
|
}
|
|
|
|
return predictions;
|
|
}
|
|
|
|
std::vector<int> DecisionTree::predict(std::vector<std::vector<int>>& X)
|
|
{
|
|
// Convert to tensor
|
|
long n = X.size();
|
|
long m = X.at(0).size();
|
|
torch::Tensor X_tensor = platform::TensorUtils::to_matrix(X);
|
|
auto predictions = predict(X_tensor);
|
|
std::vector<int> result = platform::TensorUtils::to_vector<int>(predictions);
|
|
return result;
|
|
}
|
|
|
|
torch::Tensor DecisionTree::predict_proba(torch::Tensor& X)
|
|
{
|
|
if (!fitted) {
|
|
throw std::runtime_error(CLASSIFIER_NOT_FITTED);
|
|
}
|
|
|
|
int n_samples = X.size(1);
|
|
torch::Tensor probabilities = torch::zeros({ n_samples, n_classes });
|
|
|
|
for (int i = 0; i < n_samples; i++) {
|
|
auto sample = X.index({ torch::indexing::Slice(), i }).ravel();
|
|
probabilities[i] = predictProbaSample(sample);
|
|
}
|
|
|
|
return probabilities;
|
|
}
|
|
|
|
std::vector<std::vector<double>> DecisionTree::predict_proba(std::vector<std::vector<int>>& X)
|
|
{
|
|
auto n_samples = X.at(0).size();
|
|
// Convert to tensor
|
|
torch::Tensor X_tensor = platform::TensorUtils::to_matrix(X);
|
|
auto proba_tensor = predict_proba(X_tensor);
|
|
std::vector<std::vector<double>> result(n_samples, std::vector<double>(n_classes, 0.0));
|
|
|
|
for (int i = 0; i < n_samples; i++) {
|
|
for (int j = 0; j < n_classes; j++) {
|
|
result[i][j] = proba_tensor[i][j].item<double>();
|
|
}
|
|
}
|
|
|
|
return result;
|
|
}
|
|
|
|
int DecisionTree::predictSample(const torch::Tensor& x) const
|
|
{
|
|
if (!fitted) {
|
|
throw std::runtime_error(CLASSIFIER_NOT_FITTED);
|
|
}
|
|
|
|
if (x.size(0) != n) { // n debería ser el número de características
|
|
throw std::runtime_error("Input sample has wrong number of features");
|
|
}
|
|
|
|
const TreeNode* leaf = traverseTree(x, root.get());
|
|
return leaf->predicted_class;
|
|
}
|
|
torch::Tensor DecisionTree::predictProbaSample(const torch::Tensor& x) const
|
|
{
|
|
const TreeNode* leaf = traverseTree(x, root.get());
|
|
return leaf->class_probabilities.clone();
|
|
}
|
|
|
|
|
|
const TreeNode* DecisionTree::traverseTree(const torch::Tensor& x, const TreeNode* node) const
|
|
{
|
|
if (!node) {
|
|
throw std::runtime_error("Null node encountered during tree traversal");
|
|
}
|
|
|
|
if (node->is_leaf) {
|
|
return node;
|
|
}
|
|
|
|
if (node->split_feature < 0 || node->split_feature >= x.size(0)) {
|
|
throw std::runtime_error("Invalid split_feature index: " + std::to_string(node->split_feature));
|
|
}
|
|
|
|
int feature_value = x[node->split_feature].item<int>();
|
|
|
|
if (feature_value == node->split_value) {
|
|
if (!node->left) {
|
|
throw std::runtime_error("Missing left child in tree");
|
|
}
|
|
return traverseTree(x, node->left.get());
|
|
} else {
|
|
if (!node->right) {
|
|
throw std::runtime_error("Missing right child in tree");
|
|
}
|
|
return traverseTree(x, node->right.get());
|
|
}
|
|
}
|
|
|
|
std::vector<std::string> DecisionTree::graph(const std::string& title) const
|
|
{
|
|
std::vector<std::string> lines;
|
|
lines.push_back("digraph DecisionTree {");
|
|
lines.push_back(" rankdir=TB;");
|
|
lines.push_back(" node [shape=box, style=\"filled, rounded\", fontname=\"helvetica\"];");
|
|
lines.push_back(" edge [fontname=\"helvetica\"];");
|
|
|
|
if (!title.empty()) {
|
|
lines.push_back(" label=\"" + title + "\";");
|
|
lines.push_back(" labelloc=t;");
|
|
}
|
|
|
|
if (root) {
|
|
int node_id = 0;
|
|
treeToGraph(root.get(), lines, node_id);
|
|
}
|
|
|
|
lines.push_back("}");
|
|
return lines;
|
|
}
|
|
|
|
void DecisionTree::treeToGraph(
|
|
const TreeNode* node,
|
|
std::vector<std::string>& lines,
|
|
int& node_id,
|
|
int parent_id,
|
|
const std::string& edge_label) const
|
|
{
|
|
|
|
int current_id = node_id++;
|
|
std::stringstream ss;
|
|
|
|
if (node->is_leaf) {
|
|
// Leaf node
|
|
ss << " node" << current_id << " [label=\"Class: " << node->predicted_class;
|
|
ss << "\\nProb: " << std::fixed << std::setprecision(3)
|
|
<< node->class_probabilities[node->predicted_class].item<float>();
|
|
ss << "\", fillcolor=\"lightblue\"];";
|
|
lines.push_back(ss.str());
|
|
} else {
|
|
// Internal node
|
|
ss << " node" << current_id << " [label=\"" << features[node->split_feature];
|
|
ss << " = " << node->split_value << "?\", fillcolor=\"lightgreen\"];";
|
|
lines.push_back(ss.str());
|
|
}
|
|
|
|
// Add edge from parent
|
|
if (parent_id >= 0) {
|
|
ss.str("");
|
|
ss << " node" << parent_id << " -> node" << current_id;
|
|
if (!edge_label.empty()) {
|
|
ss << " [label=\"" << edge_label << "\"];";
|
|
} else {
|
|
ss << ";";
|
|
}
|
|
lines.push_back(ss.str());
|
|
}
|
|
|
|
// Recurse on children
|
|
if (!node->is_leaf) {
|
|
if (node->left) {
|
|
treeToGraph(node->left.get(), lines, node_id, current_id, "Yes");
|
|
}
|
|
if (node->right) {
|
|
treeToGraph(node->right.get(), lines, node_id, current_id, "No");
|
|
}
|
|
}
|
|
}
|
|
|
|
} // namespace bayesnet
|