Add DecisionTree with tests
Makefile (1 line changed)
@@ -96,6 +96,7 @@ opt = ""
test: ## Run tests (opt="-s") to verbose output the tests, (opt="-c='Test Maximum Spanning Tree'") to run only that section
	@echo ">>> Running Platform tests...";
	@$(MAKE) clean
	@$(MAKE) debug
	@cmake --build $(f_debug) -t $(test_targets) --parallel
	@for t in $(test_targets); do \
		if [ -f $(f_debug)/tests/$$t ]; then \
src/CMakeLists.txt

@@ -20,6 +20,8 @@ add_executable(
     results/Result.cpp
     experimental_clfs/XA1DE.cpp
     experimental_clfs/ExpClf.cpp
+    experimental_clfs/DecisionTree.cpp

 )
 target_link_libraries(b_best Boost::boost "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy "${XLSXWRITER_LIB}")

@@ -33,7 +35,7 @@ add_executable(b_grid commands/b_grid.cpp ${grid_sources}
     results/Result.cpp
     experimental_clfs/XA1DE.cpp
     experimental_clfs/ExpClf.cpp
     experimental_clfs/AdaBoost.cpp
+    experimental_clfs/DecisionTree.cpp
 )
 target_link_libraries(b_grid ${MPI_CXX_LIBRARIES} "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy)

@@ -45,6 +47,8 @@ add_executable(b_list commands/b_list.cpp
     results/Result.cpp results/ResultsDatasetExcel.cpp results/ResultsDataset.cpp results/ResultsDatasetConsole.cpp
     experimental_clfs/XA1DE.cpp
     experimental_clfs/ExpClf.cpp
+    experimental_clfs/DecisionTree.cpp

 )
 target_link_libraries(b_list "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy "${XLSXWRITER_LIB}")

@@ -58,6 +62,8 @@ add_executable(b_main commands/b_main.cpp ${main_sources}
     experimental_clfs/XA1DE.cpp
     experimental_clfs/ExpClf.cpp
+    experimental_clfs/DecisionTree.cpp

 )
 target_link_libraries(b_main PRIVATE nlohmann_json::nlohmann_json "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy)
src/experimental_clfs/AdaBoost.cpp

@@ -5,18 +5,19 @@
 // ***************************************************************

 #include "AdaBoost.h"
+#include "DecisionTree.h"
 #include <cmath>
 #include <algorithm>
 #include <numeric>
 #include <sstream>
 #include <iomanip>

-namespace platform {
+namespace bayesnet {

-    AdaBoost::AdaBoost(int n_estimators)
-        : Ensemble(true), n_estimators(n_estimators)
+    AdaBoost::AdaBoost(int n_estimators, int max_depth)
+        : Ensemble(true), n_estimators(n_estimators), base_max_depth(max_depth)
     {
-        validHyperparameters = { "n_estimators" };
+        validHyperparameters = { "n_estimators", "base_max_depth" };
     }

     void AdaBoost::buildModel(const torch::Tensor& weights)
@@ -89,20 +90,14 @@ namespace platform {

     std::unique_ptr<Classifier> AdaBoost::trainBaseEstimator(const torch::Tensor& weights)
     {
-        // Create a new classifier instance
-        // You need to implement this based on your specific base classifier
-        // For example, if using Decision Trees:
-        // auto classifier = std::make_unique<DecisionTree>();
+        // Create a decision tree with specified max depth
+        // For AdaBoost, we typically use shallow trees (stumps with max_depth=1)
+        auto tree = std::make_unique<DecisionTree>(base_max_depth);

-        // Or if using a factory method:
-        // auto classifier = ClassifierFactory::create("DecisionTree");
+        // Fit the tree with the current sample weights
+        tree->fit(dataset, features, className, states, weights, Smoothing_t::NONE);

-        // Placeholder - replace with actual classifier creation
-        throw std::runtime_error("AdaBoost::trainBaseEstimator - You need to implement base classifier creation");

-        // Once you have the classifier creation implemented, uncomment:
-        // classifier->fit(dataset, features, className, states, weights, Smoothing_t::NONE);
-        // return classifier;
+        return tree;
     }

     double AdaBoost::calculateWeightedError(Classifier* estimator, const torch::Tensor& weights)
@@ -192,8 +187,9 @@ namespace platform {
         return graph_lines;
     }

-    void AdaBoost::setHyperparameters(const nlohmann::json& hyperparameters)
+    void AdaBoost::setHyperparameters(const nlohmann::json& hyperparameters_)
     {
+        auto hyperparameters = hyperparameters_;
         // Set hyperparameters from JSON
         auto it = hyperparameters.find("n_estimators");
         if (it != hyperparameters.end()) {
@@ -201,14 +197,18 @@ namespace platform {
             if (n_estimators <= 0) {
                 throw std::invalid_argument("n_estimators must be positive");
             }
+            hyperparameters.erase("n_estimators"); // Remove 'n_estimators' if present
         }

-        // Check for invalid hyperparameters
-        for (auto& [key, value] : hyperparameters.items()) {
-            if (std::find(validHyperparameters.begin(), validHyperparameters.end(), key) == validHyperparameters.end()) {
-                throw std::invalid_argument("Invalid hyperparameter: " + key);
-            }
-        }
+        it = hyperparameters.find("base_max_depth");
+        if (it != hyperparameters.end()) {
+            base_max_depth = it->get<int>();
+            if (base_max_depth <= 0) {
+                throw std::invalid_argument("base_max_depth must be positive");
+            }
+            hyperparameters.erase("base_max_depth"); // Remove 'base_max_depth' if present
+        }
+        Ensemble::setHyperparameters(hyperparameters);
     }

} // namespace bayesnet
src/experimental_clfs/AdaBoost.h

@@ -9,13 +9,12 @@

 #include <vector>
 #include <memory>
 #include <torch/torch.h>
-#include <bayesnet/ensembles/Ensemble.h>
+#include "bayesnet/ensembles/Ensemble.h"

-namespace platform {
-    class AdaBoost : public bayesnet::Ensemble {
+namespace bayesnet {
+    class AdaBoost : public Ensemble {
     public:
-        explicit AdaBoost(int n_estimators = 100);
+        explicit AdaBoost(int n_estimators = 50, int max_depth = 1);
         virtual ~AdaBoost() = default;

         // Override base class methods
@@ -24,10 +23,15 @@ namespace platform {
         // AdaBoost specific methods
         void setNEstimators(int n_estimators) { this->n_estimators = n_estimators; }
         int getNEstimators() const { return n_estimators; }
+        void setBaseMaxDepth(int depth) { this->base_max_depth = depth; }
+        int getBaseMaxDepth() const { return base_max_depth; }

         // Get the weight of each base estimator
         std::vector<double> getEstimatorWeights() const { return alphas; }

         // Get training errors for each iteration
         std::vector<double> getTrainingErrors() const { return training_errors; }

+        // Override setHyperparameters from BaseClassifier
+        void setHyperparameters(const nlohmann::json& hyperparameters) override;

@@ -37,6 +41,7 @@ namespace platform {

     private:
         int n_estimators;
+        int base_max_depth;                  // Max depth for base decision trees
         std::vector<double> alphas;          // Weight of each base estimator
         std::vector<double> training_errors; // Training error at each iteration
         torch::Tensor sample_weights;        // Current sample weights
src/experimental_clfs/DecisionTree.cpp (new file, 519 lines)

@@ -0,0 +1,519 @@
// ***************************************************************
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
// SPDX-FileType: SOURCE
// SPDX-License-Identifier: MIT
// ***************************************************************

#include "DecisionTree.h"
#include <algorithm>
#include <numeric>
#include <sstream>
#include <iomanip>
#include <limits>
#include <iostream>
#include "TensorUtils.hpp"
namespace bayesnet {

    DecisionTree::DecisionTree(int max_depth, int min_samples_split, int min_samples_leaf)
        : Classifier(Network()), max_depth(max_depth),
        min_samples_split(min_samples_split), min_samples_leaf(min_samples_leaf)
    {
        validHyperparameters = { "max_depth", "min_samples_split", "min_samples_leaf" };
    }

    void DecisionTree::setHyperparameters(const nlohmann::json& hyperparameters_)
    {
        auto hyperparameters = hyperparameters_;
        // Set hyperparameters from JSON
        auto it = hyperparameters.find("max_depth");
        if (it != hyperparameters.end()) {
            max_depth = it->get<int>();
            hyperparameters.erase("max_depth"); // Remove 'max_depth' if present
        }

        it = hyperparameters.find("min_samples_split");
        if (it != hyperparameters.end()) {
            min_samples_split = it->get<int>();
            hyperparameters.erase("min_samples_split"); // Remove 'min_samples_split' if present
        }

        it = hyperparameters.find("min_samples_leaf");
        if (it != hyperparameters.end()) {
            min_samples_leaf = it->get<int>();
            hyperparameters.erase("min_samples_leaf"); // Remove 'min_samples_leaf' if present
        }
        Classifier::setHyperparameters(hyperparameters);
        checkValues();
    }
    void DecisionTree::checkValues()
    {
        if (max_depth <= 0) {
            throw std::invalid_argument("max_depth must be positive");
        }
        if (min_samples_leaf <= 0) {
            throw std::invalid_argument("min_samples_leaf must be positive");
        }
        if (min_samples_split <= 0) {
            throw std::invalid_argument("min_samples_split must be positive");
        }
    }
    void DecisionTree::buildModel(const torch::Tensor& weights)
    {
        // Extract features (X) and labels (y) from dataset
        auto X = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), torch::indexing::Slice() }).t();
        auto y = dataset.index({ -1, torch::indexing::Slice() });

        if (X.size(0) != y.size(0)) {
            throw std::runtime_error("X and y must have the same number of samples");
        }

        n_classes = states[className].size();

        // Use provided weights or uniform weights
        torch::Tensor sample_weights;
        if (weights.defined() && weights.numel() > 0) {
            if (weights.size(0) != X.size(0)) {
                throw std::runtime_error("weights must have the same length as number of samples");
            }
            sample_weights = weights;
        } else {
            sample_weights = torch::ones({ X.size(0) }) / X.size(0);
        }

        // Normalize weights
        sample_weights = sample_weights / sample_weights.sum();

        // Build the tree
        root = buildTree(X, y, sample_weights, 0);

        // Mark as fitted
        fitted = true;
    }
    bool DecisionTree::validateTensors(const torch::Tensor& X, const torch::Tensor& y,
        const torch::Tensor& sample_weights) const
    {
        if (X.size(0) != y.size(0) || X.size(0) != sample_weights.size(0)) {
            return false;
        }
        if (X.size(0) == 0) {
            return false;
        }
        return true;
    }
    std::unique_ptr<TreeNode> DecisionTree::buildTree(
        const torch::Tensor& X,
        const torch::Tensor& y,
        const torch::Tensor& sample_weights,
        int current_depth)
    {
        auto node = std::make_unique<TreeNode>();
        int n_samples = y.size(0);

        // Check stopping criteria
        auto unique = at::_unique(y);
        bool should_stop = (current_depth >= max_depth) ||
            (n_samples < min_samples_split) ||
            (std::get<0>(unique).size(0) == 1); // All samples same class

        if (should_stop || n_samples <= min_samples_leaf) {
            // Create leaf node
            node->is_leaf = true;

            // Calculate class probabilities
            node->class_probabilities = torch::zeros({ n_classes });

            for (int i = 0; i < n_samples; i++) {
                int class_idx = y[i].item<int>();
                node->class_probabilities[class_idx] += sample_weights[i].item<float>();
            }

            // Normalize probabilities
            node->class_probabilities /= node->class_probabilities.sum();

            // Set predicted class as the one with highest probability
            node->predicted_class = torch::argmax(node->class_probabilities).item<int>();

            return node;
        }

        // Find best split
        SplitInfo best_split = findBestSplit(X, y, sample_weights);

        // If no valid split found, create leaf
        if (best_split.feature_index == -1 || best_split.impurity_decrease <= 0) {
            node->is_leaf = true;

            // Calculate class probabilities
            node->class_probabilities = torch::zeros({ n_classes });

            for (int i = 0; i < n_samples; i++) {
                int class_idx = y[i].item<int>();
                node->class_probabilities[class_idx] += sample_weights[i].item<float>();
            }

            node->class_probabilities /= node->class_probabilities.sum();
            node->predicted_class = torch::argmax(node->class_probabilities).item<int>();

            return node;
        }

        // Create internal node
        node->is_leaf = false;
        node->split_feature = best_split.feature_index;
        node->split_value = best_split.split_value;

        // Split data
        auto left_X = X.index({ best_split.left_mask });
        auto left_y = y.index({ best_split.left_mask });
        auto left_weights = sample_weights.index({ best_split.left_mask });

        auto right_X = X.index({ best_split.right_mask });
        auto right_y = y.index({ best_split.right_mask });
        auto right_weights = sample_weights.index({ best_split.right_mask });

        // Recursively build subtrees
        if (left_X.size(0) >= min_samples_leaf) {
            node->left = buildTree(left_X, left_y, left_weights, current_depth + 1);
        } else {
            // Force leaf if not enough samples
            node->left = std::make_unique<TreeNode>();
            node->left->is_leaf = true;
            auto mode = std::get<0>(torch::mode(left_y));
            node->left->predicted_class = mode.item<int>();
            node->left->class_probabilities = torch::zeros({ n_classes });
            node->left->class_probabilities[node->left->predicted_class] = 1.0;
        }

        if (right_X.size(0) >= min_samples_leaf) {
            node->right = buildTree(right_X, right_y, right_weights, current_depth + 1);
        } else {
            // Force leaf if not enough samples
            node->right = std::make_unique<TreeNode>();
            node->right->is_leaf = true;
            auto mode = std::get<0>(torch::mode(right_y));
            node->right->predicted_class = mode.item<int>();
            node->right->class_probabilities = torch::zeros({ n_classes });
            node->right->class_probabilities[node->right->predicted_class] = 1.0;
        }

        return node;
    }
    DecisionTree::SplitInfo DecisionTree::findBestSplit(
        const torch::Tensor& X,
        const torch::Tensor& y,
        const torch::Tensor& sample_weights)
    {
        SplitInfo best_split;
        best_split.feature_index = -1;
        best_split.split_value = -1;
        best_split.impurity_decrease = -std::numeric_limits<double>::infinity();

        int n_features = X.size(1);
        int n_samples = X.size(0);

        // Calculate impurity of current node
        double current_impurity = calculateGiniImpurity(y, sample_weights);
        double total_weight = sample_weights.sum().item<double>();

        // Try each feature
        for (int feat_idx = 0; feat_idx < n_features; feat_idx++) {
            auto feature_values = X.index({ torch::indexing::Slice(), feat_idx });
            auto unique_values = std::get<0>(torch::unique_consecutive(std::get<0>(torch::sort(feature_values))));

            // Try each unique value as split point
            for (int i = 0; i < unique_values.size(0); i++) {
                int split_val = unique_values[i].item<int>();

                // Create masks for left and right splits
                auto left_mask = feature_values == split_val;
                auto right_mask = ~left_mask;

                int left_count = left_mask.sum().item<int>();
                int right_count = right_mask.sum().item<int>();

                // Skip if split doesn't satisfy minimum samples requirement
                if (left_count < min_samples_leaf || right_count < min_samples_leaf) {
                    continue;
                }

                // Calculate weighted impurities
                auto left_y = y.index({ left_mask });
                auto left_weights = sample_weights.index({ left_mask });
                double left_weight = left_weights.sum().item<double>();
                double left_impurity = calculateGiniImpurity(left_y, left_weights);

                auto right_y = y.index({ right_mask });
                auto right_weights = sample_weights.index({ right_mask });
                double right_weight = right_weights.sum().item<double>();
                double right_impurity = calculateGiniImpurity(right_y, right_weights);

                // Calculate impurity decrease
                double impurity_decrease = current_impurity -
                    (left_weight / total_weight * left_impurity +
                        right_weight / total_weight * right_impurity);

                // Update best split if this is better
                if (impurity_decrease > best_split.impurity_decrease) {
                    best_split.feature_index = feat_idx;
                    best_split.split_value = split_val;
                    best_split.impurity_decrease = impurity_decrease;
                    best_split.left_mask = left_mask;
                    best_split.right_mask = right_mask;
                }
            }
        }

        return best_split;
    }
    double DecisionTree::calculateGiniImpurity(
        const torch::Tensor& y,
        const torch::Tensor& sample_weights)
    {
        if (y.size(0) == 0 || sample_weights.size(0) == 0) {
            return 0.0;
        }

        if (y.size(0) != sample_weights.size(0)) {
            throw std::runtime_error("y and sample_weights must have same size");
        }

        torch::Tensor class_weights = torch::zeros({ n_classes });

        // Calculate weighted class counts
        for (int i = 0; i < y.size(0); i++) {
            int class_idx = y[i].item<int>();

            if (class_idx < 0 || class_idx >= n_classes) {
                throw std::runtime_error("Invalid class index: " + std::to_string(class_idx));
            }

            class_weights[class_idx] += sample_weights[i].item<float>();
        }

        // Normalize
        double total_weight = class_weights.sum().item<double>();
        if (total_weight == 0) return 0.0;

        class_weights /= total_weight;

        // Calculate Gini impurity: 1 - sum(p_i^2)
        double gini = 1.0;
        for (int i = 0; i < n_classes; i++) {
            double p = class_weights[i].item<double>();
            gini -= p * p;
        }

        return gini;
    }
    torch::Tensor DecisionTree::predict(torch::Tensor& X)
    {
        if (!fitted) {
            throw std::runtime_error(CLASSIFIER_NOT_FITTED);
        }

        int n_samples = X.size(1);
        torch::Tensor predictions = torch::zeros({ n_samples }, torch::kInt32);

        for (int i = 0; i < n_samples; i++) {
            auto sample = X.index({ torch::indexing::Slice(), i }).ravel();
            predictions[i] = predictSample(sample);
        }

        return predictions;
    }
    // Debug helper: print a 2-D integer tensor row by row
    void dumpTensor(const torch::Tensor& tensor, const std::string& name)
    {
        std::cout << name << ": " << std::endl;
        for (int i = 0; i < tensor.size(0); i++) {
            std::cout << "[";
            for (int j = 0; j < tensor.size(1); j++) {
                std::cout << tensor[i][j].item<int>() << " ";
            }
            std::cout << "]" << std::endl;
        }
        std::cout << std::endl;
    }
    // Debug helper: print a vector of vectors row by row
    void dumpVector(const std::vector<std::vector<int>>& vec, const std::string& name)
    {
        std::cout << name << ": " << std::endl;
        for (const auto& row : vec) {
            std::cout << "[";
            for (const auto& val : row) {
                std::cout << val << " ";
            }
            std::cout << "] " << std::endl;
        }
        std::cout << std::endl;
    }
    std::vector<int> DecisionTree::predict(std::vector<std::vector<int>>& X)
    {
        // Convert to tensor
        torch::Tensor X_tensor = platform::TensorUtils::to_matrix(X);
        auto predictions = predict(X_tensor);
        std::vector<int> result = platform::TensorUtils::to_vector<int>(predictions);
        return result;
    }
    torch::Tensor DecisionTree::predict_proba(torch::Tensor& X)
    {
        if (!fitted) {
            throw std::runtime_error(CLASSIFIER_NOT_FITTED);
        }

        int n_samples = X.size(1);
        torch::Tensor probabilities = torch::zeros({ n_samples, n_classes });

        for (int i = 0; i < n_samples; i++) {
            auto sample = X.index({ torch::indexing::Slice(), i }).ravel();
            probabilities[i] = predictProbaSample(sample);
        }

        return probabilities;
    }

    std::vector<std::vector<double>> DecisionTree::predict_proba(std::vector<std::vector<int>>& X)
    {
        auto n_samples = X.at(0).size();
        // Convert to tensor
        torch::Tensor X_tensor = platform::TensorUtils::to_matrix(X);
        auto proba_tensor = predict_proba(X_tensor);
        std::vector<std::vector<double>> result(n_samples, std::vector<double>(n_classes, 0.0));

        for (int i = 0; i < n_samples; i++) {
            for (int j = 0; j < n_classes; j++) {
                result[i][j] = proba_tensor[i][j].item<double>();
            }
        }

        return result;
    }
    int DecisionTree::predictSample(const torch::Tensor& x) const
    {
        if (!fitted) {
            throw std::runtime_error(CLASSIFIER_NOT_FITTED);
        }

        if (x.size(0) != n) { // n is the number of features
            throw std::runtime_error("Input sample has wrong number of features");
        }

        const TreeNode* leaf = traverseTree(x, root.get());
        return leaf->predicted_class;
    }

    torch::Tensor DecisionTree::predictProbaSample(const torch::Tensor& x) const
    {
        const TreeNode* leaf = traverseTree(x, root.get());
        return leaf->class_probabilities.clone();
    }

    const TreeNode* DecisionTree::traverseTree(const torch::Tensor& x, const TreeNode* node) const
    {
        if (!node) {
            throw std::runtime_error("Null node encountered during tree traversal");
        }

        if (node->is_leaf) {
            return node;
        }

        if (node->split_feature < 0 || node->split_feature >= x.size(0)) {
            throw std::runtime_error("Invalid split_feature index: " + std::to_string(node->split_feature));
        }

        int feature_value = x[node->split_feature].item<int>();

        if (feature_value == node->split_value) {
            if (!node->left) {
                throw std::runtime_error("Missing left child in tree");
            }
            return traverseTree(x, node->left.get());
        } else {
            if (!node->right) {
                throw std::runtime_error("Missing right child in tree");
            }
            return traverseTree(x, node->right.get());
        }
    }
    std::vector<std::string> DecisionTree::graph(const std::string& title) const
    {
        std::vector<std::string> lines;
        lines.push_back("digraph DecisionTree {");
        lines.push_back(" rankdir=TB;");
        lines.push_back(" node [shape=box, style=\"filled, rounded\", fontname=\"helvetica\"];");
        lines.push_back(" edge [fontname=\"helvetica\"];");

        if (!title.empty()) {
            lines.push_back(" label=\"" + title + "\";");
            lines.push_back(" labelloc=t;");
        }

        if (root) {
            int node_id = 0;
            treeToGraph(root.get(), lines, node_id);
        }

        lines.push_back("}");
        return lines;
    }

    void DecisionTree::treeToGraph(
        const TreeNode* node,
        std::vector<std::string>& lines,
        int& node_id,
        int parent_id,
        const std::string& edge_label) const
    {
        int current_id = node_id++;
        std::stringstream ss;

        if (node->is_leaf) {
            // Leaf node
            ss << " node" << current_id << " [label=\"Class: " << node->predicted_class;
            ss << "\\nProb: " << std::fixed << std::setprecision(3)
                << node->class_probabilities[node->predicted_class].item<float>();
            ss << "\", fillcolor=\"lightblue\"];";
            lines.push_back(ss.str());
        } else {
            // Internal node
            ss << " node" << current_id << " [label=\"" << features[node->split_feature];
            ss << " = " << node->split_value << "?\", fillcolor=\"lightgreen\"];";
            lines.push_back(ss.str());
        }

        // Add edge from parent
        if (parent_id >= 0) {
            ss.str("");
            ss << " node" << parent_id << " -> node" << current_id;
            if (!edge_label.empty()) {
                ss << " [label=\"" << edge_label << "\"];";
            } else {
                ss << ";";
            }
            lines.push_back(ss.str());
        }

        // Recurse on children
        if (!node->is_leaf) {
            if (node->left) {
                treeToGraph(node->left.get(), lines, node_id, current_id, "Yes");
            }
            if (node->right) {
                treeToGraph(node->right.get(), lines, node_id, current_id, "No");
            }
        }
    }

} // namespace bayesnet
src/experimental_clfs/DecisionTree.h (new file, 129 lines)

@@ -0,0 +1,129 @@
// ***************************************************************
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
// SPDX-FileType: SOURCE
// SPDX-License-Identifier: MIT
// ***************************************************************

#ifndef DECISION_TREE_H
#define DECISION_TREE_H

#include <memory>
#include <vector>
#include <map>
#include <torch/torch.h>
#include "bayesnet/classifiers/Classifier.h"

namespace bayesnet {

    // Forward declaration
    struct TreeNode;

    class DecisionTree : public Classifier {
    public:
        explicit DecisionTree(int max_depth = 3, int min_samples_split = 2, int min_samples_leaf = 1);
        virtual ~DecisionTree() = default;

        // Override graph method to show tree structure
        std::vector<std::string> graph(const std::string& title = "") const override;

        // Setters for hyperparameters
        void setMaxDepth(int depth) { max_depth = depth; checkValues(); }
        void setMinSamplesSplit(int samples) { min_samples_split = samples; checkValues(); }
        void setMinSamplesLeaf(int samples) { min_samples_leaf = samples; checkValues(); }

        // Override setHyperparameters
        void setHyperparameters(const nlohmann::json& hyperparameters) override;

        torch::Tensor predict(torch::Tensor& X) override;
        std::vector<int> predict(std::vector<std::vector<int>>& X) override;
        torch::Tensor predict_proba(torch::Tensor& X) override;
        std::vector<std::vector<double>> predict_proba(std::vector<std::vector<int>>& X);

    protected:
        void buildModel(const torch::Tensor& weights) override;
        void trainModel(const torch::Tensor& weights, const Smoothing_t smoothing) override
        {
            // Decision trees do not require training in the traditional sense
            // as they are built from the data directly.
            // This method can be used to set weights or other parameters if needed.
        }
    private:
        void checkValues();
        bool validateTensors(const torch::Tensor& X, const torch::Tensor& y, const torch::Tensor& sample_weights) const;
        // Tree hyperparameters
        int max_depth;
        int min_samples_split;
        int min_samples_leaf;
        int n_classes; // Number of classes in the target variable

        // Root of the decision tree
        std::unique_ptr<TreeNode> root;

        // Build tree recursively
        std::unique_ptr<TreeNode> buildTree(
            const torch::Tensor& X,
            const torch::Tensor& y,
            const torch::Tensor& sample_weights,
            int current_depth
        );

        // Find best split for a node
        struct SplitInfo {
            int feature_index;
            int split_value;
            double impurity_decrease;
            torch::Tensor left_mask;
            torch::Tensor right_mask;
        };

        SplitInfo findBestSplit(
            const torch::Tensor& X,
            const torch::Tensor& y,
            const torch::Tensor& sample_weights
        );

        // Calculate weighted Gini impurity for multi-class
        double calculateGiniImpurity(
            const torch::Tensor& y,
            const torch::Tensor& sample_weights
        );

        // Make predictions for a single sample
        int predictSample(const torch::Tensor& x) const;

        // Make probabilistic predictions for a single sample
        torch::Tensor predictProbaSample(const torch::Tensor& x) const;

        // Traverse tree to find leaf node
        const TreeNode* traverseTree(const torch::Tensor& x, const TreeNode* node) const;

        // Convert tree to graph representation
        void treeToGraph(
            const TreeNode* node,
            std::vector<std::string>& lines,
            int& node_id,
            int parent_id = -1,
            const std::string& edge_label = ""
        ) const;
    };

    // Tree node structure
    struct TreeNode {
        bool is_leaf;

        // For internal nodes
        int split_feature;
        int split_value;
        std::unique_ptr<TreeNode> left;
        std::unique_ptr<TreeNode> right;

        // For leaf nodes
        int predicted_class;
        torch::Tensor class_probabilities; // Probability for each class

        TreeNode() : is_leaf(false), split_feature(-1), split_value(-1), predicted_class(-1) {}
    };

} // namespace bayesnet

#endif // DECISION_TREE_H
src/experimental_clfs/README.md (new file, 142 lines)

@@ -0,0 +1,142 @@
# AdaBoost and DecisionTree Classifier Implementation

This implementation provides both a Decision Tree classifier and a multi-class AdaBoost classifier based on the SAMME (Stagewise Additive Modeling using a Multi-class Exponential loss) algorithm described in the paper "Multi-class AdaBoost" by Zhu et al. Implemented in C++ with the help of <https://claude.ai>.

## Components

### 1. DecisionTree Classifier

A classic decision tree implementation that:

- Supports multi-class classification
- Handles weighted samples (essential for boosting)
- Uses Gini impurity as the splitting criterion (see the formula after this list)
- Works with discrete/categorical features
- Provides both class predictions and probability estimates
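For a node with weighted class proportions `p_k`, the Gini impurity computed by `calculateGiniImpurity` is:

```
G = 1 - Σ p_k²
```

A pure node (all weight on one class) gives G = 0, while a 50/50 binary node gives G = 0.5. `findBestSplit` chooses the feature/value pair that maximizes the weight-averaged decrease of this quantity.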
#### Key Features

- **Max Depth Control**: Limit tree depth to create weak learners
- **Minimum Samples**: Control minimum samples for splitting and leaf nodes
- **Weighted Training**: Properly handles sample weights for boosting
- **Visualization**: Generates DOT format graphs of the tree structure

#### Hyperparameters

- `max_depth`: Maximum depth of the tree (default: 3)
- `min_samples_split`: Minimum samples required to split a node (default: 2)
- `min_samples_leaf`: Minimum samples required in a leaf node (default: 1)
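All three can also be set at once through `setHyperparameters`. A minimal sketch, using the JSON keys validated by `DecisionTree::setHyperparameters`:

```cpp
#include <nlohmann/json.hpp>

DecisionTree dt;                 // defaults: max_depth=3, min_samples_split=2, min_samples_leaf=1
nlohmann::json params;
params["max_depth"] = 7;         // non-positive values make checkValues() throw
params["min_samples_split"] = 4;
params["min_samples_leaf"] = 2;
dt.setHyperparameters(params);   // leftover keys are forwarded to Classifier::setHyperparameters
```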
### 2. AdaBoost Classifier

A multi-class AdaBoost implementation using DecisionTree as base estimators:

- **SAMME Algorithm**: Implements the multi-class extension of AdaBoost
- **Automatic Stumps**: Uses decision stumps (max_depth=1) by default
- **Early Stopping**: Stops if base classifier performs worse than random
- **Ensemble Visualization**: Shows the weighted combination of base estimators

#### Key Features

- **Multi-class Support**: Natural extension to K classes
- **Base Estimator Control**: Configure depth of base decision trees
- **Training Monitoring**: Track training errors and estimator weights
- **Probability Estimates**: Provides class probability predictions

#### Hyperparameters

- `n_estimators`: Number of base estimators to train (default: 50)
- `base_max_depth`: Maximum depth for base decision trees (default: 1)
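These map onto `AdaBoost::setHyperparameters` in the same way; a minimal sketch mirroring the DecisionTree example above:

```cpp
AdaBoost ada;                   // defaults: n_estimators=50, base_max_depth=1
nlohmann::json params;
params["n_estimators"] = 100;   // must be positive
params["base_max_depth"] = 2;   // must be positive
ada.setHyperparameters(params); // leftover keys go to Ensemble::setHyperparameters

// The same values are also reachable through the typed setters
ada.setNEstimators(100);
ada.setBaseMaxDepth(2);
```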
## Algorithm Details

The SAMME algorithm differs from binary AdaBoost in the calculation of the estimator weight (alpha):

```
α = log((1 - err) / err) + log(K - 1)
```

where `K` is the number of classes. This formula ensures that:

- When K = 2, it reduces to standard AdaBoost
- For K > 2, base classifiers only need to be better than random guessing (1/K) rather than 50%
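As a concrete check of the formula, a small standalone sketch (not part of the committed code) that evaluates α for a weighted error of 0.3 on a 3-class problem:

```cpp
#include <cmath>
#include <iostream>

// SAMME estimator weight: alpha = log((1 - err) / err) + log(K - 1)
double samme_alpha(double err, int n_classes) {
    return std::log((1.0 - err) / err) + std::log(n_classes - 1.0);
}

int main() {
    // err = 0.3, K = 3: log(0.7 / 0.3) + log(2) ≈ 0.847 + 0.693 ≈ 1.540
    std::cout << samme_alpha(0.3, 3) << "\n";
    return 0;
}
```

α is positive exactly when err < (K − 1)/K, i.e. when the base classifier beats random guessing; at or beyond that error the estimator would receive zero or negative weight, which is why training can stop early.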
## Usage Example

```cpp
// Create AdaBoost with decision stumps
AdaBoost ada(100, 1); // 100 estimators, max_depth=1

// Train
ada.fit(X_train, y_train, features, className, states, Smoothing_t::NONE);

// Predict
auto predictions = ada.predict(X_test);
auto probabilities = ada.predict_proba(X_test);

// Evaluate
float accuracy = ada.score(X_test, y_test);

// Get ensemble information
auto weights = ada.getEstimatorWeights();
auto errors = ada.getTrainingErrors();
```

## Implementation Structure

```
AdaBoost (inherits from Ensemble)
└── Uses multiple DecisionTree instances as base estimators
    └── DecisionTree (inherits from Classifier)
        └── Implements weighted Gini impurity splitting
```

## Visualization

Both classifiers support graph visualization:

- **DecisionTree**: Shows the tree structure with split conditions
- **AdaBoost**: Shows the ensemble of weighted base estimators

Generate visualizations using:

```cpp
auto graph = classifier.graph("Title");
```

## Data Format

Both classifiers expect discrete/categorical data:

- **Features**: Integer values representing categories (stored in `torch::Tensor` or `std::vector<std::vector<int>>`)
- **Labels**: Integer values representing class indices (0, 1, ..., K-1)
- **States**: Map defining possible values for each feature and the class variable
- **Sample Weights**: Optional weights for each training sample (important for boosting)

Example data setup:

```cpp
// Features matrix (n_features x n_samples)
torch::Tensor X = torch::tensor({{0, 1, 2}, {1, 0, 1}}); // 2 features, 3 samples

// Labels vector
torch::Tensor y = torch::tensor({0, 1, 0}); // 3 samples

// States definition
std::map<std::string, std::vector<int>> states;
states["feature1"] = {0, 1, 2}; // Feature 1 can take values 0, 1, or 2
states["feature2"] = {0, 1};    // Feature 2 can take values 0 or 1
states["class"] = {0, 1};       // Binary classification
```

## Notes

- The implementation handles discrete/categorical features as indicated by the int-based data structures
- Sample weights are properly propagated through the tree building process
- The DecisionTree implementation uses equality testing for splits (suitable for categorical data)
- Both classifiers support the standard fit/predict interface from the base framework

## References

- Zhu, J., Zou, H., Rosset, S., & Hastie, T. (2009). Multi-class AdaBoost. Statistics and Its Interface, 2(3), 349-360.
- Breiman, L., Friedman, J., Olshen, R., & Stone, C. (1984). Classification and Regression Trees. Wadsworth, Belmont, CA.
src/experimental_clfs/TensorUtils.hpp

@@ -45,6 +45,19 @@ namespace platform {

         return data;
     }
+    static torch::Tensor to_matrix(const std::vector<std::vector<int>>& data)
+    {
+        if (data.empty()) return torch::empty({ 0, 0 }, torch::kInt64);
+        size_t rows = data.size();
+        size_t cols = data[0].size();
+        torch::Tensor tensor = torch::empty({ static_cast<long>(rows), static_cast<long>(cols) }, torch::kInt64);
+        for (size_t i = 0; i < rows; ++i) {
+            for (size_t j = 0; j < cols; ++j) {
+                tensor.index_put_({ static_cast<long>(i), static_cast<long>(j) }, data[i][j]);
+            }
+        }
+        return tensor;
+    }
 };
}
@@ -23,10 +23,11 @@
 #include <pyclassifiers/ODTE.h>
 #include <pyclassifiers/SVC.h>
 #include <pyclassifiers/XGBoost.h>
-#include <pyclassifiers/AdaBoost.h>
+#include <pyclassifiers/AdaBoostPy.h>
 #include <pyclassifiers/RandomForest.h>
 #include "../experimental_clfs/XA1DE.h"
-#include "../experimental_clfs/AdaBoost.h"
+// #include "../experimental_clfs/AdaBoost.h"
+#include "../experimental_clfs/DecisionTree.h"

 namespace platform {
     class Models {
@@ -35,10 +35,12 @@ namespace platform {
         [](void) -> bayesnet::BaseClassifier* { return new pywrap::RandomForest();});
     static Registrar registrarXGB("XGBoost",
         [](void) -> bayesnet::BaseClassifier* { return new pywrap::XGBoost();});
-    static Registrar registrarAda("AdaBoostPy",
-        [](void) -> bayesnet::BaseClassifier* { return new pywrap::AdaBoost();});
+    // static Registrar registrarAda2("AdaBoost",
+    //     [](void) -> bayesnet::BaseClassifier* { return new platform::AdaBoost();});
+    static Registrar registrarAdaPy("AdaBoostPy",
+        [](void) -> bayesnet::BaseClassifier* { return new pywrap::AdaBoostPy();});
+    // static Registrar registrarAda("AdaBoost",
+    //     [](void) -> bayesnet::BaseClassifier* { return new bayesnet::AdaBoost();});
+    static Registrar registrarDT("DecisionTree",
+        [](void) -> bayesnet::BaseClassifier* { return new bayesnet::DecisionTree();});
     static Registrar registrarXSPODE("XSPODE",
         [](void) -> bayesnet::BaseClassifier* { return new bayesnet::XSpode(0);});
     static Registrar registrarXSP2DE("XSP2DE",
tests/CMakeLists.txt

@@ -12,11 +12,11 @@ if(ENABLE_TESTING)
         ${Bayesnet_INCLUDE_DIRS}
     )
     set(TEST_SOURCES_PLATFORM
-        TestUtils.cpp TestPlatform.cpp TestResult.cpp TestScores.cpp
+        TestUtils.cpp TestPlatform.cpp TestResult.cpp TestScores.cpp TestDecisionTree.cpp
         ${Platform_SOURCE_DIR}/src/common/Datasets.cpp ${Platform_SOURCE_DIR}/src/common/Dataset.cpp ${Platform_SOURCE_DIR}/src/common/Discretization.cpp
-        ${Platform_SOURCE_DIR}/src/main/Scores.cpp
+        ${Platform_SOURCE_DIR}/src/main/Scores.cpp ${Platform_SOURCE_DIR}/src/experimental_clfs/DecisionTree.cpp
     )
     add_executable(${TEST_PLATFORM} ${TEST_SOURCES_PLATFORM})
-    target_link_libraries(${TEST_PLATFORM} PUBLIC "${TORCH_LIBRARIES}" mdlp Catch2::Catch2WithMain BayesNet)
+    target_link_libraries(${TEST_PLATFORM} PUBLIC "${TORCH_LIBRARIES}" fimdlp Catch2::Catch2WithMain bayesnet)
     add_test(NAME ${TEST_PLATFORM} COMMAND ${TEST_PLATFORM})
 endif(ENABLE_TESTING)
tests/TestDecisionTree.cpp (new file, 303 lines)

@@ -0,0 +1,303 @@
// ***************************************************************
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
// SPDX-FileType: SOURCE
// SPDX-License-Identifier: MIT
// ***************************************************************

#include <catch2/catch_test_macros.hpp>
#include <catch2/catch_approx.hpp>
#include <catch2/matchers/catch_matchers_string.hpp>
#include <catch2/matchers/catch_matchers_vector.hpp>
#include <torch/torch.h>
#include <memory>
#include <stdexcept>
#include "experimental_clfs/DecisionTree.h"
#include "TestUtils.h"

using namespace bayesnet;
using namespace Catch::Matchers;

TEST_CASE("DecisionTree Construction", "[DecisionTree]")
{
    SECTION("Default constructor")
    {
        REQUIRE_NOTHROW(DecisionTree());
    }

    SECTION("Constructor with parameters")
    {
        REQUIRE_NOTHROW(DecisionTree(5, 10, 3));
    }
}
TEST_CASE("DecisionTree Hyperparameter Setting", "[DecisionTree]")
|
||||
{
|
||||
DecisionTree dt;
|
||||
|
||||
SECTION("Set individual hyperparameters")
|
||||
{
|
||||
REQUIRE_NOTHROW(dt.setMaxDepth(10));
|
||||
REQUIRE_NOTHROW(dt.setMinSamplesSplit(5));
|
||||
REQUIRE_NOTHROW(dt.setMinSamplesLeaf(2));
|
||||
}
|
||||
|
||||
SECTION("Set hyperparameters via JSON")
|
||||
{
|
||||
nlohmann::json params;
|
||||
params["max_depth"] = 7;
|
||||
params["min_samples_split"] = 4;
|
||||
params["min_samples_leaf"] = 2;
|
||||
|
||||
REQUIRE_NOTHROW(dt.setHyperparameters(params));
|
||||
}
|
||||
|
||||
SECTION("Invalid hyperparameters should throw")
|
||||
{
|
||||
nlohmann::json params;
|
||||
|
||||
// Negative max_depth
|
||||
params["max_depth"] = -1;
|
||||
REQUIRE_THROWS_AS(dt.setHyperparameters(params), std::invalid_argument);
|
||||
|
||||
// Zero min_samples_split
|
||||
params["max_depth"] = 5;
|
||||
params["min_samples_split"] = 0;
|
||||
REQUIRE_THROWS_AS(dt.setHyperparameters(params), std::invalid_argument);
|
||||
|
||||
// Negative min_samples_leaf
|
||||
params["min_samples_split"] = 2;
|
||||
params["min_samples_leaf"] = -5;
|
||||
REQUIRE_THROWS_AS(dt.setHyperparameters(params), std::invalid_argument);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("DecisionTree Basic Functionality", "[DecisionTree]")
|
||||
{
|
||||
// Create a simple dataset
|
||||
int n_samples = 20;
|
||||
int n_features = 2;
|
||||
|
||||
std::vector<std::vector<int>> X(n_features, std::vector<int>(n_samples));
|
||||
std::vector<int> y(n_samples);
|
||||
|
||||
// Simple pattern: class depends on first feature
|
||||
for (int i = 0; i < n_samples; i++) {
|
||||
X[0][i] = i < 10 ? 0 : 1;
|
||||
X[1][i] = i % 2;
|
||||
y[i] = X[0][i]; // Class equals first feature
|
||||
}
|
||||
|
||||
std::vector<std::string> features = { "f1", "f2" };
|
||||
std::string className = "class";
|
||||
std::map<std::string, std::vector<int>> states;
|
||||
states["f1"] = { 0, 1 };
|
||||
states["f2"] = { 0, 1 };
|
||||
states["class"] = { 0, 1 };
|
||||
|
||||
SECTION("Training with vector interface")
|
||||
{
|
||||
DecisionTree dt(3, 2, 1);
|
||||
REQUIRE_NOTHROW(dt.fit(X, y, features, className, states, Smoothing_t::NONE));
|
||||
|
||||
auto predictions = dt.predict(X);
|
||||
REQUIRE(predictions.size() == static_cast<size_t>(n_samples));
|
||||
|
||||
// Should achieve perfect accuracy on this simple dataset
|
||||
int correct = 0;
|
||||
for (size_t i = 0; i < predictions.size(); i++) {
|
||||
if (predictions[i] == y[i]) correct++;
|
||||
}
|
||||
REQUIRE(correct == n_samples);
|
||||
}
|
||||
|
||||
SECTION("Prediction before fitting")
|
||||
{
|
||||
DecisionTree dt;
|
||||
REQUIRE_THROWS_WITH(dt.predict(X),
|
||||
ContainsSubstring("Classifier has not been fitted"));
|
||||
}
|
||||
|
||||
SECTION("Probability predictions")
|
||||
{
|
||||
DecisionTree dt(3, 2, 1);
|
||||
dt.fit(X, y, features, className, states, Smoothing_t::NONE);
|
||||
|
||||
auto proba = dt.predict_proba(X);
|
||||
REQUIRE(proba.size() == static_cast<size_t>(n_samples));
|
||||
REQUIRE(proba[0].size() == 2); // Two classes
|
||||
|
||||
// Check probabilities sum to 1 and probabilities are valid
|
||||
auto predictions = dt.predict(X);
|
||||
for (size_t i = 0; i < proba.size(); i++) {
|
||||
auto p = proba[i];
|
||||
auto pred = predictions[i];
|
||||
REQUIRE(p.size() == 2);
|
||||
REQUIRE(p[0] >= 0.0);
|
||||
REQUIRE(p[1] >= 0.0);
|
||||
double sum = p[0] + p[1];
|
||||
//Check that prodict_proba matches the expected predict value
|
||||
REQUIRE(pred == (p[0] > p[1] ? 0 : 1));
|
||||
REQUIRE(sum == Catch::Approx(1.0).epsilon(1e-6));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("DecisionTree on Iris Dataset", "[DecisionTree][iris]")
|
||||
{
|
||||
auto raw = RawDatasets("iris", true);
|
||||
|
||||
SECTION("Training with dataset format")
|
||||
{
|
||||
DecisionTree dt(5, 2, 1);
|
||||
|
||||
INFO("Dataset shape: " << raw.dataset.sizes());
|
||||
INFO("Features: " << raw.featurest.size());
|
||||
INFO("Samples: " << raw.nSamples);
|
||||
|
||||
// DecisionTree expects dataset in format: features x samples, with labels as last row
|
||||
REQUIRE_NOTHROW(dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE));
|
||||
|
||||
// Test prediction
|
||||
auto predictions = dt.predict(raw.Xt);
|
||||
REQUIRE(predictions.size(0) == raw.yt.size(0));
|
||||
|
||||
// Calculate accuracy
|
||||
auto correct = torch::sum(predictions == raw.yt).item<int>();
|
||||
double accuracy = static_cast<double>(correct) / raw.yt.size(0);
|
||||
REQUIRE(accuracy > 0.97); // Reasonable accuracy for Iris
|
||||
}
|
||||
|
||||
SECTION("Training with vector interface")
|
||||
{
|
||||
DecisionTree dt(5, 2, 1);
|
||||
|
||||
REQUIRE_NOTHROW(dt.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv, Smoothing_t::NONE));
|
||||
|
||||
// std::cout << "Tree structure:\n";
|
||||
// auto graph_lines = dt.graph("Iris Decision Tree");
|
||||
// for (const auto& line : graph_lines) {
|
||||
// std::cout << line << "\n";
|
||||
// }
|
||||
auto predictions = dt.predict(raw.Xv);
|
||||
REQUIRE(predictions.size() == raw.yv.size());
|
||||
}
|
||||
|
||||
SECTION("Different tree depths")
|
||||
{
|
||||
std::vector<int> depths = { 1, 3, 5 };
|
||||
|
||||
for (int depth : depths) {
|
||||
DecisionTree dt(depth, 2, 1);
|
||||
dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE);
|
||||
|
||||
auto predictions = dt.predict(raw.Xt);
|
||||
REQUIRE(predictions.size(0) == raw.yt.size(0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("DecisionTree Edge Cases", "[DecisionTree]")
|
||||
{
|
||||
auto raw = RawDatasets("iris", true);
|
||||
|
||||
SECTION("Very shallow tree")
|
||||
{
|
||||
DecisionTree dt(1, 2, 1); // depth = 1
|
||||
dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE);
|
||||
|
||||
auto predictions = dt.predict(raw.Xt);
|
||||
REQUIRE(predictions.size(0) == raw.yt.size(0));
|
||||
|
||||
// With depth 1, should have at most 2 unique predictions
|
||||
auto unique_vals = at::_unique(predictions);
|
||||
REQUIRE(std::get<0>(unique_vals).size(0) <= 2);
|
||||
}
|
||||
|
||||
SECTION("High min_samples_split")
|
||||
{
|
||||
DecisionTree dt(10, 50, 1);
|
||||
dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE);
|
||||
|
||||
auto predictions = dt.predict(raw.Xt);
|
||||
REQUIRE(predictions.size(0) == raw.yt.size(0));
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("DecisionTree Graph Visualization", "[DecisionTree]")
|
||||
{
|
||||
// Simple dataset
|
||||
std::vector<std::vector<int>> X = { {0,0,0,1}, {0,1,1,1} }; // XOR pattern
|
||||
std::vector<int> y = { 0, 1, 1, 0 }; // XOR pattern
|
||||
std::vector<std::string> features = { "x1", "x2" };
|
||||
std::string className = "xor";
|
||||
std::map<std::string, std::vector<int>> states;
|
||||
states["x1"] = { 0, 1 };
|
||||
states["x2"] = { 0, 1 };
|
||||
states["xor"] = { 0, 1 };
|
||||
|
||||
SECTION("Graph generation")
|
||||
{
|
||||
DecisionTree dt(2, 1, 1);
|
||||
dt.fit(X, y, features, className, states, Smoothing_t::NONE);
|
||||
|
||||
auto graph_lines = dt.graph();
|
||||
|
||||
REQUIRE(graph_lines.size() > 2);
|
||||
REQUIRE(graph_lines.front() == "digraph DecisionTree {");
|
||||
REQUIRE(graph_lines.back() == "}");
|
||||
|
||||
// Should contain node definitions
|
||||
bool has_nodes = false;
|
||||
for (const auto& line : graph_lines) {
|
||||
if (line.find("node") != std::string::npos) {
|
||||
has_nodes = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
REQUIRE(has_nodes);
|
||||
}
|
||||
|
||||
SECTION("Graph with title")
|
||||
{
|
||||
DecisionTree dt(2, 1, 1);
|
||||
dt.fit(X, y, features, className, states, Smoothing_t::NONE);
|
||||
|
||||
auto graph_lines = dt.graph("XOR Tree");
|
||||
|
||||
bool has_title = false;
|
||||
for (const auto& line : graph_lines) {
|
||||
if (line.find("label=\"XOR Tree\"") != std::string::npos) {
|
||||
has_title = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
REQUIRE(has_title);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("DecisionTree with Weights", "[DecisionTree]")
|
||||
{
|
||||
auto raw = RawDatasets("iris", true);
|
||||
|
||||
SECTION("Uniform weights")
|
||||
{
|
||||
DecisionTree dt(5, 2, 1);
|
||||
dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, raw.weights, Smoothing_t::NONE);
|
||||
|
||||
auto predictions = dt.predict(raw.Xt);
|
||||
REQUIRE(predictions.size(0) == raw.yt.size(0));
|
||||
}
|
||||
|
||||
SECTION("Non-uniform weights")
|
||||
{
|
||||
auto weights = torch::ones({ raw.nSamples });
|
||||
weights.index({ torch::indexing::Slice(0, 50) }) *= 2.0; // Emphasize first class
|
||||
weights = weights / weights.sum();
|
||||
|
||||
DecisionTree dt(5, 2, 1);
|
||||
dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, weights, Smoothing_t::NONE);
|
||||
|
||||
auto predictions = dt.predict(raw.Xt);
|
||||
REQUIRE(predictions.size(0) == raw.yt.size(0));
|
||||
}
|
||||
}
|
tests/TestPlatform.cpp

@@ -7,7 +7,7 @@
 #include <string>
 #include "TestUtils.h"
 #include "folding.hpp"
-#include <ArffFiles.hpp>
+#include <ArffFiles/ArffFiles.hpp>
 #include <bayesnet/classifiers/TAN.h>
 #include "config_platform.h"

@@ -20,17 +20,17 @@ TEST_CASE("Test Platform version", "[Platform]")
 TEST_CASE("Test Folding library version", "[Folding]")
 {
     std::string version = folding::KFold(5, 100).version();
-    REQUIRE(version == "1.1.0");
+    REQUIRE(version == "1.1.1");
 }
 TEST_CASE("Test BayesNet version", "[BayesNet]")
 {
     std::string version = bayesnet::TAN().getVersion();
-    REQUIRE(version == "1.0.6");
+    REQUIRE(version == "1.1.2");
 }
 TEST_CASE("Test mdlp version", "[mdlp]")
 {
     std::string version = mdlp::CPPFImdlp::version();
-    REQUIRE(version == "2.0.0");
+    REQUIRE(version == "2.0.1");
 }
 TEST_CASE("Test Arff version", "[Arff]")
 {
tests/TestScores.cpp

@@ -14,38 +14,40 @@
 using json = nlohmann::ordered_json;
 auto epsilon = 1e-4;

-void make_test_bin(int TP, int TN, int FP, int FN, std::vector<int>& y_test, std::vector<int>& y_pred)
+void make_test_bin(int TP, int TN, int FP, int FN, std::vector<int>& y_test, torch::Tensor& y_pred)
 {
-    // TP
+    std::vector<std::array<double, 2>> probs;
+    // TP: true positive (label 1, predicted 1)
     for (int i = 0; i < TP; i++) {
         y_test.push_back(1);
-        y_pred.push_back(1);
+        probs.push_back({ 0.0, 1.0 }); // P(class 0)=0, P(class 1)=1
     }
-    // TN
+    // TN: true negative (label 0, predicted 0)
     for (int i = 0; i < TN; i++) {
         y_test.push_back(0);
-        y_pred.push_back(0);
+        probs.push_back({ 1.0, 0.0 }); // P(class 0)=1, P(class 1)=0
     }
-    // FP
+    // FP: false positive (label 0, predicted 1)
     for (int i = 0; i < FP; i++) {
         y_test.push_back(0);
-        y_pred.push_back(1);
+        probs.push_back({ 0.0, 1.0 }); // P(class 0)=0, P(class 1)=1
     }
-    // FN
+    // FN: false negative (label 1, predicted 0)
     for (int i = 0; i < FN; i++) {
         y_test.push_back(1);
-        y_pred.push_back(0);
+        probs.push_back({ 1.0, 0.0 }); // P(class 0)=1, P(class 1)=0
     }
+    // Convert to torch::Tensor of double, shape [N,2]
+    y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 2 }, torch::kFloat64).clone();
 }
TEST_CASE("Scores binary", "[Scores]")
|
||||
{
|
||||
std::vector<int> y_test;
|
||||
std::vector<int> y_pred;
|
||||
torch::Tensor y_pred;
|
||||
make_test_bin(197, 210, 52, 41, y_test, y_pred);
|
||||
auto y_test_tensor = torch::tensor(y_test, torch::kInt32);
|
||||
auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
|
||||
platform::Scores scores(y_test_tensor, y_pred_tensor, 2);
|
||||
platform::Scores scores(y_test_tensor, y_pred, 2);
|
||||
REQUIRE(scores.accuracy() == Catch::Approx(0.814).epsilon(epsilon));
|
||||
REQUIRE(scores.f1_score(0) == Catch::Approx(0.818713));
|
||||
REQUIRE(scores.f1_score(1) == Catch::Approx(0.809035));
|
||||
@@ -64,10 +66,23 @@ TEST_CASE("Scores binary", "[Scores]")
 TEST_CASE("Scores multiclass", "[Scores]")
 {
     std::vector<int> y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 };
-    std::vector<int> y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 };
+    // Refactor y_pred to a tensor of shape [10, 3] with probabilities
+    std::vector<std::array<double, 3>> probs = {
+        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
+        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
+        { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
+        { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
+        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
+        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
+        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
+        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
+        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
+        { 0.0, 0.0, 1.0 }  // P(class 0)=0, P(class 1)=0, P(class 2)=1
+    };
+    torch::Tensor y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 3 }, torch::kFloat64).clone();
+    // Convert y_test to a tensor
     auto y_test_tensor = torch::tensor(y_test, torch::kInt32);
-    auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
-    platform::Scores scores(y_test_tensor, y_pred_tensor, 3);
+    platform::Scores scores(y_test_tensor, y_pred, 3);
     REQUIRE(scores.accuracy() == Catch::Approx(0.6).epsilon(epsilon));
     REQUIRE(scores.f1_score(0) == Catch::Approx(0.666667));
     REQUIRE(scores.f1_score(1) == Catch::Approx(0.4));
@@ -84,10 +99,21 @@ TEST_CASE("Scores multiclass", "[Scores]")
 TEST_CASE("Test Confusion Matrix Values", "[Scores]")
 {
     std::vector<int> y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 };
-    std::vector<int> y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 };
+    std::vector<std::array<double, 3>> probs = {
+        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
+        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
+        { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
+        { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
+        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
+        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
+        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
+        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
+        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
+        { 0.0, 0.0, 1.0 }  // P(class 0)=0, P(class 1)=0, P(class 2)=1
+    };
+    torch::Tensor y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 3 }, torch::kFloat64).clone();
     auto y_test_tensor = torch::tensor(y_test, torch::kInt32);
-    auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
-    platform::Scores scores(y_test_tensor, y_pred_tensor, 3);
+    platform::Scores scores(y_test_tensor, y_pred, 3);
     auto confusion_matrix = scores.get_confusion_matrix();
     REQUIRE(confusion_matrix[0][0].item<int>() == 2);
     REQUIRE(confusion_matrix[0][1].item<int>() == 1);
@@ -102,11 +128,22 @@ TEST_CASE("Test Confusion Matrix Values", "[Scores]")
TEST_CASE("Confusion Matrix JSON", "[Scores]")
{
    std::vector<int> y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 };
    std::vector<int> y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 };
    std::vector<std::array<double, 3>> probs = {
        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
        { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
        { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
        { 0.0, 0.0, 1.0 }  // P(class 0)=0, P(class 1)=0, P(class 2)=1
    };
    torch::Tensor y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 3 }, torch::kFloat64).clone();
    auto y_test_tensor = torch::tensor(y_test, torch::kInt32);
    auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
    std::vector<std::string> labels = { "Aeroplane", "Boat", "Car" };
    platform::Scores scores(y_test_tensor, y_pred_tensor, 3, labels);
    platform::Scores scores(y_test_tensor, y_pred, 3, labels);
    auto res_json_int = scores.get_confusion_matrix_json();
    REQUIRE(res_json_int[0][0] == 2);
    REQUIRE(res_json_int[0][1] == 1);
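Judging by the integer indexing in the assertions, the default JSON form serializes the matrix row by row as nested arrays. A hypothetical nlohmann::json sketch of that shape (the actual serializer is not shown in this diff), using the full matrix this fixture produces:

#include <nlohmann/json.hpp>

int main() {
    // Full confusion matrix for this fixture, cm[true][predicted].
    int cm[3][3] = { { 2, 1, 0 }, { 0, 1, 0 }, { 1, 2, 3 } };
    nlohmann::json j = nlohmann::json::array();
    for (int i = 0; i < 3; ++i)
        j.push_back({ cm[i][0], cm[i][1], cm[i][2] });
    // j[0][0] == 2 and j[0][1] == 1, the cells the test checks.
}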
@@ -131,11 +168,22 @@ TEST_CASE("Confusion Matrix JSON", "[Scores]")
TEST_CASE("Classification Report", "[Scores]")
{
    std::vector<int> y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 };
    std::vector<int> y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 };
    std::vector<std::array<double, 3>> probs = {
        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
        { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
        { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
        { 0.0, 0.0, 1.0 }  // P(class 0)=0, P(class 1)=0, P(class 2)=1
    };
    torch::Tensor y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 3 }, torch::kFloat64).clone();
    auto y_test_tensor = torch::tensor(y_test, torch::kInt32);
    auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
    std::vector<std::string> labels = { "Aeroplane", "Boat", "Car" };
    platform::Scores scores(y_test_tensor, y_pred_tensor, 3, labels);
    platform::Scores scores(y_test_tensor, y_pred, 3, labels);
    auto report = scores.classification_report(Colors::BLUE(), "train");
    auto json_matrix = scores.get_confusion_matrix_json(true);
    platform::Scores scores2(json_matrix);
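For reference, the per-class figures such a report derives from this fixture: class 0 is predicted 3 times with 2 hits (precision 2/3) and occurs 3 times (recall 2/3), so F1 = 2 * (2/3 * 2/3) / (2/3 + 2/3) = 2/3, about 0.666667; class 1 is predicted 4 times with 1 hit (precision 1/4) and occurs once (recall 1), so F1 = 2 * (0.25 * 1) / 1.25 = 0.4. These are exactly the f1_score values asserted in the multiclass test above.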
@@ -144,11 +192,22 @@ TEST_CASE("Classification Report", "[Scores]")
TEST_CASE("JSON constructor", "[Scores]")
{
    std::vector<int> y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 };
    std::vector<int> y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 };
    std::vector<std::array<double, 3>> probs = {
        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
        { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
        { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
        { 0.0, 0.0, 1.0 }  // P(class 0)=0, P(class 1)=0, P(class 2)=1
    };
    torch::Tensor y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 3 }, torch::kFloat64).clone();
    auto y_test_tensor = torch::tensor(y_test, torch::kInt32);
    auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
    std::vector<std::string> labels = { "Car", "Boat", "Aeroplane" };
    platform::Scores scores(y_test_tensor, y_pred_tensor, 3, labels);
    platform::Scores scores(y_test_tensor, y_pred, 3, labels);
    auto res_json_int = scores.get_confusion_matrix_json();
    platform::Scores scores2(res_json_int);
    REQUIRE(scores.accuracy() == scores2.accuracy());
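The round trip works because every metric compared here is a function of the confusion matrix alone; accuracy, for instance, is the trace over the total, (2 + 1 + 3) / 10 = 0.6 for this fixture, so a Scores rebuilt from the serialized matrix necessarily reproduces it.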
@@ -173,17 +232,14 @@ TEST_CASE("JSON constructor", "[Scores]")
TEST_CASE("Aggregate", "[Scores]")
{
    std::vector<int> y_test;
    std::vector<int> y_pred;
    torch::Tensor y_pred;
    make_test_bin(197, 210, 52, 41, y_test, y_pred);
    auto y_test_tensor = torch::tensor(y_test, torch::kInt32);
    auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
    platform::Scores scores(y_test_tensor, y_pred_tensor, 2);
    platform::Scores scores(y_test_tensor, y_pred, 2);
    y_test.clear();
    y_pred.clear();
    make_test_bin(227, 187, 39, 47, y_test, y_pred);
    auto y_test_tensor2 = torch::tensor(y_test, torch::kInt32);
    auto y_pred_tensor2 = torch::tensor(y_pred, torch::kInt32);
    platform::Scores scores2(y_test_tensor2, y_pred_tensor2, 2);
    platform::Scores scores2(y_test_tensor2, y_pred, 2);
    scores.aggregate(scores2);
    REQUIRE(scores.accuracy() == Catch::Approx(0.821).epsilon(epsilon));
    REQUIRE(scores.f1_score(0) == Catch::Approx(0.8160329));
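make_test_bin is a test helper not shown in this diff, but its four counts evidently split into correct and incorrect predictions per class, since the assertions work out exactly on that reading: each call yields 500 samples (197 + 210 + 52 + 41 and 227 + 187 + 39 + 47), and pooling the correct ones gives (197 + 210 + 227 + 187) / 1000 = 0.821, the asserted aggregate accuracy.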
@@ -195,11 +251,9 @@ TEST_CASE("Aggregate", "[Scores]")
    REQUIRE(scores.f1_weighted() == Catch::Approx(0.8209856));
    REQUIRE(scores.f1_macro() == Catch::Approx(0.8208694));
    y_test.clear();
    y_pred.clear();
    make_test_bin(197 + 227, 210 + 187, 52 + 39, 41 + 47, y_test, y_pred);
    y_test_tensor = torch::tensor(y_test, torch::kInt32);
    y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
    platform::Scores scores3(y_test_tensor, y_pred_tensor, 2);
    platform::Scores scores3(y_test_tensor, y_pred, 2);
    for (int i = 0; i < 2; ++i) {
        REQUIRE(scores3.f1_score(i) == scores.f1_score(i));
        REQUIRE(scores3.precision(i) == scores.precision(i));
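The equivalence the loop checks holds because aggregating two Scores amounts to adding their confusion matrices cell by cell, and precision, recall and F1 are functions of that sum. A minimal sketch with hypothetical matrices for the two batches (the exact placement of the error counts is an assumption):

int main() {
    // Hypothetical per-batch confusion matrices, cm[true][predicted].
    int cm_a[2][2] = { { 197, 41 }, { 52, 210 } };
    int cm_b[2][2] = { { 227, 47 }, { 39, 187 } };
    int agg[2][2];
    for (int i = 0; i < 2; ++i)
        for (int j = 0; j < 2; ++j)
            agg[i][j] = cm_a[i][j] + cm_b[i][j]; // identical to scoring the pooled samples
}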
@@ -212,11 +266,22 @@ TEST_CASE("Aggregate", "[Scores]")
TEST_CASE("Order of keys", "[Scores]")
{
    std::vector<int> y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 };
    std::vector<int> y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 };
    std::vector<std::array<double, 3>> probs = {
        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
        { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
        { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
        { 0.0, 0.0, 1.0 }  // P(class 0)=0, P(class 1)=0, P(class 2)=1
    };
    torch::Tensor y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 3 }, torch::kFloat64).clone();
    auto y_test_tensor = torch::tensor(y_test, torch::kInt32);
    auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
    std::vector<std::string> labels = { "Car", "Boat", "Aeroplane" };
    platform::Scores scores(y_test_tensor, y_pred_tensor, 3, labels);
    platform::Scores scores(y_test_tensor, y_pred, 3, labels);
    auto res_json_int = scores.get_confusion_matrix_json(true);
    // Make a temp file and store the json
    std::string filename = "temp.json";
@@ -5,7 +5,7 @@
#include <vector>
#include <map>
#include <tuple>
#include <ArffFiles.hpp>
#include <ArffFiles/ArffFiles.hpp>
#include <fimdlp/CPPFImdlp.h>

bool file_exists(const std::string& name);
@@ -7,6 +7,7 @@
    "fimdlp",
    "libtorch-bin",
    "folding",
    "catch2",
    "argparse"
  ],
  "overrides": [
@@ -30,9 +31,13 @@
      "name": "argpase",
      "version": "3.2"
    },
    {
      "name": "catch2",
      "version": "3.8.1"
    },
    {
      "name": "nlohmann-json",
      "version": "3.11.3"
    }
  ]
}
}
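In a vcpkg manifest, "overrides" entries pin the exact version resolved for a dependency, superseding what the baseline would otherwise pick; they are only honored in the top-level manifest, and each name must match the dependency it overrides.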