Add DecisionTree with tests

Makefile
@@ -96,6 +96,7 @@ opt = ""
 test: ## Run tests (opt="-s") to verbose output the tests, (opt="-c='Test Maximum Spanning Tree'") to run only that section
 	@echo ">>> Running Platform tests...";
 	@$(MAKE) clean
+	@$(MAKE) debug
 	@cmake --build $(f_debug) -t $(test_targets) --parallel
 	@for t in $(test_targets); do \
 		if [ -f $(f_debug)/tests/$$t ]; then \
@@ -20,6 +20,8 @@ add_executable(
     results/Result.cpp
     experimental_clfs/XA1DE.cpp
     experimental_clfs/ExpClf.cpp
+    experimental_clfs/DecisionTree.cpp
+
 )
 target_link_libraries(b_best Boost::boost "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy "${XLSXWRITER_LIB}")
 
@@ -33,7 +35,7 @@ add_executable(b_grid commands/b_grid.cpp ${grid_sources}
     results/Result.cpp
     experimental_clfs/XA1DE.cpp
     experimental_clfs/ExpClf.cpp
-    experimental_clfs/AdaBoost.cpp
+    experimental_clfs/DecisionTree.cpp
 )
 target_link_libraries(b_grid ${MPI_CXX_LIBRARIES} "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy)
 
@@ -45,6 +47,8 @@ add_executable(b_list commands/b_list.cpp
     results/Result.cpp results/ResultsDatasetExcel.cpp results/ResultsDataset.cpp results/ResultsDatasetConsole.cpp
     experimental_clfs/XA1DE.cpp
     experimental_clfs/ExpClf.cpp
+    experimental_clfs/DecisionTree.cpp
+
 )
 target_link_libraries(b_list "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy "${XLSXWRITER_LIB}")
 
@@ -58,6 +62,8 @@ add_executable(b_main commands/b_main.cpp ${main_sources}
     experimental_clfs/XA1DE.cpp
     experimental_clfs/ExpClf.cpp
     experimental_clfs/ExpClf.cpp
+    experimental_clfs/DecisionTree.cpp
+
 )
 target_link_libraries(b_main PRIVATE nlohmann_json::nlohmann_json "${PyClassifiers}" bayesnet::bayesnet fimdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" Boost::python Boost::numpy)
 
@@ -5,18 +5,19 @@
 // ***************************************************************
 
 #include "AdaBoost.h"
+#include "DecisionTree.h"
 #include <cmath>
 #include <algorithm>
 #include <numeric>
 #include <sstream>
 #include <iomanip>
 
-namespace platform {
+namespace bayesnet {
 
-    AdaBoost::AdaBoost(int n_estimators)
-        : Ensemble(true), n_estimators(n_estimators)
+    AdaBoost::AdaBoost(int n_estimators, int max_depth)
+        : Ensemble(true), n_estimators(n_estimators), base_max_depth(max_depth)
     {
-        validHyperparameters = { "n_estimators" };
+        validHyperparameters = { "n_estimators", "base_max_depth" };
     }
 
     void AdaBoost::buildModel(const torch::Tensor& weights)
@@ -89,20 +90,14 @@ namespace platform {
 
     std::unique_ptr<Classifier> AdaBoost::trainBaseEstimator(const torch::Tensor& weights)
     {
-        // Create a new classifier instance
-        // You need to implement this based on your specific base classifier
-        // For example, if using Decision Trees:
-        // auto classifier = std::make_unique<DecisionTree>();
-
-        // Or if using a factory method:
-        // auto classifier = ClassifierFactory::create("DecisionTree");
-
-        // Placeholder - replace with actual classifier creation
-        throw std::runtime_error("AdaBoost::trainBaseEstimator - You need to implement base classifier creation");
-
-        // Once you have the classifier creation implemented, uncomment:
-        // classifier->fit(dataset, features, className, states, weights, Smoothing_t::NONE);
-        // return classifier;
+        // Create a decision tree with specified max depth
+        // For AdaBoost, we typically use shallow trees (stumps with max_depth=1)
+        auto tree = std::make_unique<DecisionTree>(base_max_depth);
+
+        // Fit the tree with the current sample weights
+        tree->fit(dataset, features, className, states, weights, Smoothing_t::NONE);
+
+        return tree;
     }
 
     double AdaBoost::calculateWeightedError(Classifier* estimator, const torch::Tensor& weights)
@@ -192,8 +187,9 @@ namespace platform {
         return graph_lines;
     }
 
-    void AdaBoost::setHyperparameters(const nlohmann::json& hyperparameters)
+    void AdaBoost::setHyperparameters(const nlohmann::json& hyperparameters_)
     {
+        auto hyperparameters = hyperparameters_;
         // Set hyperparameters from JSON
         auto it = hyperparameters.find("n_estimators");
         if (it != hyperparameters.end()) {
@@ -201,14 +197,18 @@ namespace platform {
             if (n_estimators <= 0) {
                 throw std::invalid_argument("n_estimators must be positive");
             }
+            hyperparameters.erase("n_estimators"); // Remove 'n_estimators' if present
         }
 
-        // Check for invalid hyperparameters
-        for (auto& [key, value] : hyperparameters.items()) {
-            if (std::find(validHyperparameters.begin(), validHyperparameters.end(), key) == validHyperparameters.end()) {
-                throw std::invalid_argument("Invalid hyperparameter: " + key);
+        it = hyperparameters.find("base_max_depth");
+        if (it != hyperparameters.end()) {
+            base_max_depth = it->get<int>();
+            if (base_max_depth <= 0) {
+                throw std::invalid_argument("base_max_depth must be positive");
             }
+            hyperparameters.erase("base_max_depth"); // Remove 'base_max_depth' if present
         }
+        Ensemble::setHyperparameters(hyperparameters);
     }
 
 } // namespace bayesnet
@@ -9,13 +9,12 @@
 
 #include <vector>
 #include <memory>
-#include <torch/torch.h>
-#include <bayesnet/ensembles/Ensemble.h>
+#include "bayesnet/ensembles/Ensemble.h"
 
-namespace platform {
-    class AdaBoost : public bayesnet::Ensemble {
+namespace bayesnet {
+    class AdaBoost : public Ensemble {
     public:
-        explicit AdaBoost(int n_estimators = 100);
+        explicit AdaBoost(int n_estimators = 50, int max_depth = 1);
         virtual ~AdaBoost() = default;
 
         // Override base class methods
@@ -24,10 +23,15 @@ namespace platform {
         // AdaBoost specific methods
         void setNEstimators(int n_estimators) { this->n_estimators = n_estimators; }
         int getNEstimators() const { return n_estimators; }
+        void setBaseMaxDepth(int depth) { this->base_max_depth = depth; }
+        int getBaseMaxDepth() const { return base_max_depth; }
 
         // Get the weight of each base estimator
         std::vector<double> getEstimatorWeights() const { return alphas; }
 
+        // Get training errors for each iteration
+        std::vector<double> getTrainingErrors() const { return training_errors; }
+
         // Override setHyperparameters from BaseClassifier
         void setHyperparameters(const nlohmann::json& hyperparameters) override;
 
@@ -37,6 +41,7 @@ namespace platform {
 
     private:
         int n_estimators;
+        int base_max_depth; // Max depth for base decision trees
         std::vector<double> alphas; // Weight of each base estimator
         std::vector<double> training_errors; // Training error at each iteration
         torch::Tensor sample_weights; // Current sample weights
src/experimental_clfs/DecisionTree.cpp (new file, 519 lines)
@@ -0,0 +1,519 @@
// ***************************************************************
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
// SPDX-FileType: SOURCE
// SPDX-License-Identifier: MIT
// ***************************************************************

#include "DecisionTree.h"
#include <algorithm>
#include <numeric>
#include <sstream>
#include <iomanip>
#include <limits>
#include "TensorUtils.hpp"

namespace bayesnet {

    DecisionTree::DecisionTree(int max_depth, int min_samples_split, int min_samples_leaf)
        : Classifier(Network()), max_depth(max_depth),
        min_samples_split(min_samples_split), min_samples_leaf(min_samples_leaf)
    {
        validHyperparameters = { "max_depth", "min_samples_split", "min_samples_leaf" };
    }

    void DecisionTree::setHyperparameters(const nlohmann::json& hyperparameters_)
    {
        auto hyperparameters = hyperparameters_;
        // Set hyperparameters from JSON
        auto it = hyperparameters.find("max_depth");
        if (it != hyperparameters.end()) {
            max_depth = it->get<int>();
            hyperparameters.erase("max_depth"); // Remove 'max_depth' if present
        }

        it = hyperparameters.find("min_samples_split");
        if (it != hyperparameters.end()) {
            min_samples_split = it->get<int>();
            hyperparameters.erase("min_samples_split"); // Remove 'min_samples_split' if present
        }

        it = hyperparameters.find("min_samples_leaf");
        if (it != hyperparameters.end()) {
            min_samples_leaf = it->get<int>();
            hyperparameters.erase("min_samples_leaf"); // Remove 'min_samples_leaf' if present
        }
        Classifier::setHyperparameters(hyperparameters);
        checkValues();
    }
    void DecisionTree::checkValues()
    {
        if (max_depth <= 0) {
            throw std::invalid_argument("max_depth must be positive");
        }
        if (min_samples_leaf <= 0) {
            throw std::invalid_argument("min_samples_leaf must be positive");
        }
        if (min_samples_split <= 0) {
            throw std::invalid_argument("min_samples_split must be positive");
        }
    }
    void DecisionTree::buildModel(const torch::Tensor& weights)
    {
        // Extract features (X) and labels (y) from dataset
        auto X = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), torch::indexing::Slice() }).t();
        auto y = dataset.index({ -1, torch::indexing::Slice() });

        if (X.size(0) != y.size(0)) {
            throw std::runtime_error("X and y must have the same number of samples");
        }

        n_classes = states[className].size();

        // Use provided weights or uniform weights
        torch::Tensor sample_weights;
        if (weights.defined() && weights.numel() > 0) {
            if (weights.size(0) != X.size(0)) {
                throw std::runtime_error("weights must have the same length as number of samples");
            }
            sample_weights = weights;
        } else {
            sample_weights = torch::ones({ X.size(0) }) / X.size(0);
        }

        // Normalize weights
        sample_weights = sample_weights / sample_weights.sum();

        // Build the tree
        root = buildTree(X, y, sample_weights, 0);

        // Mark as fitted
        fitted = true;
    }
    bool DecisionTree::validateTensors(const torch::Tensor& X, const torch::Tensor& y,
        const torch::Tensor& sample_weights) const
    {
        if (X.size(0) != y.size(0) || X.size(0) != sample_weights.size(0)) {
            return false;
        }
        if (X.size(0) == 0) {
            return false;
        }
        return true;
    }

    std::unique_ptr<TreeNode> DecisionTree::buildTree(
        const torch::Tensor& X,
        const torch::Tensor& y,
        const torch::Tensor& sample_weights,
        int current_depth)
    {
        auto node = std::make_unique<TreeNode>();
        int n_samples = y.size(0);

        // Check stopping criteria
        auto unique = at::_unique(y);
        bool should_stop = (current_depth >= max_depth) ||
            (n_samples < min_samples_split) ||
            (std::get<0>(unique).size(0) == 1); // All samples same class

        if (should_stop || n_samples <= min_samples_leaf) {
            // Create leaf node
            node->is_leaf = true;

            // Calculate class probabilities
            node->class_probabilities = torch::zeros({ n_classes });

            for (int i = 0; i < n_samples; i++) {
                int class_idx = y[i].item<int>();
                node->class_probabilities[class_idx] += sample_weights[i].item<float>();
            }

            // Normalize probabilities
            node->class_probabilities /= node->class_probabilities.sum();

            // Set predicted class as the one with highest probability
            node->predicted_class = torch::argmax(node->class_probabilities).item<int>();

            return node;
        }

        // Find best split
        SplitInfo best_split = findBestSplit(X, y, sample_weights);

        // If no valid split found, create leaf
        if (best_split.feature_index == -1 || best_split.impurity_decrease <= 0) {
            node->is_leaf = true;

            // Calculate class probabilities
            node->class_probabilities = torch::zeros({ n_classes });

            for (int i = 0; i < n_samples; i++) {
                int class_idx = y[i].item<int>();
                node->class_probabilities[class_idx] += sample_weights[i].item<float>();
            }

            node->class_probabilities /= node->class_probabilities.sum();
            node->predicted_class = torch::argmax(node->class_probabilities).item<int>();

            return node;
        }

        // Create internal node
        node->is_leaf = false;
        node->split_feature = best_split.feature_index;
        node->split_value = best_split.split_value;

        // Split data
        auto left_X = X.index({ best_split.left_mask });
        auto left_y = y.index({ best_split.left_mask });
        auto left_weights = sample_weights.index({ best_split.left_mask });

        auto right_X = X.index({ best_split.right_mask });
        auto right_y = y.index({ best_split.right_mask });
        auto right_weights = sample_weights.index({ best_split.right_mask });

        // Recursively build subtrees
        if (left_X.size(0) >= min_samples_leaf) {
            node->left = buildTree(left_X, left_y, left_weights, current_depth + 1);
        } else {
            // Force leaf if not enough samples
            node->left = std::make_unique<TreeNode>();
            node->left->is_leaf = true;
            auto mode = std::get<0>(torch::mode(left_y));
            node->left->predicted_class = mode.item<int>();
            node->left->class_probabilities = torch::zeros({ n_classes });
            node->left->class_probabilities[node->left->predicted_class] = 1.0;
        }

        if (right_X.size(0) >= min_samples_leaf) {
            node->right = buildTree(right_X, right_y, right_weights, current_depth + 1);
        } else {
            // Force leaf if not enough samples
            node->right = std::make_unique<TreeNode>();
            node->right->is_leaf = true;
            auto mode = std::get<0>(torch::mode(right_y));
            node->right->predicted_class = mode.item<int>();
            node->right->class_probabilities = torch::zeros({ n_classes });
            node->right->class_probabilities[node->right->predicted_class] = 1.0;
        }

        return node;
    }

    DecisionTree::SplitInfo DecisionTree::findBestSplit(
        const torch::Tensor& X,
        const torch::Tensor& y,
        const torch::Tensor& sample_weights)
    {
        SplitInfo best_split;
        best_split.feature_index = -1;
        best_split.split_value = -1;
        best_split.impurity_decrease = -std::numeric_limits<double>::infinity();

        int n_features = X.size(1);
        int n_samples = X.size(0);

        // Calculate impurity of current node
        double current_impurity = calculateGiniImpurity(y, sample_weights);
        double total_weight = sample_weights.sum().item<double>();

        // Try each feature
        for (int feat_idx = 0; feat_idx < n_features; feat_idx++) {
            auto feature_values = X.index({ torch::indexing::Slice(), feat_idx });
            auto unique_values = std::get<0>(torch::unique_consecutive(std::get<0>(torch::sort(feature_values))));

            // Try each unique value as split point
            for (int i = 0; i < unique_values.size(0); i++) {
                int split_val = unique_values[i].item<int>();

                // Create masks for left and right splits
                auto left_mask = feature_values == split_val;
                auto right_mask = ~left_mask;

                int left_count = left_mask.sum().item<int>();
                int right_count = right_mask.sum().item<int>();

                // Skip if split doesn't satisfy minimum samples requirement
                if (left_count < min_samples_leaf || right_count < min_samples_leaf) {
                    continue;
                }

                // Calculate weighted impurities
                auto left_y = y.index({ left_mask });
                auto left_weights = sample_weights.index({ left_mask });
                double left_weight = left_weights.sum().item<double>();
                double left_impurity = calculateGiniImpurity(left_y, left_weights);

                auto right_y = y.index({ right_mask });
                auto right_weights = sample_weights.index({ right_mask });
                double right_weight = right_weights.sum().item<double>();
                double right_impurity = calculateGiniImpurity(right_y, right_weights);

                // Calculate impurity decrease
                double impurity_decrease = current_impurity -
                    (left_weight / total_weight * left_impurity +
                        right_weight / total_weight * right_impurity);

                // Update best split if this is better
                if (impurity_decrease > best_split.impurity_decrease) {
                    best_split.feature_index = feat_idx;
                    best_split.split_value = split_val;
                    best_split.impurity_decrease = impurity_decrease;
                    best_split.left_mask = left_mask;
                    best_split.right_mask = right_mask;
                }
            }
        }

        return best_split;
    }

    double DecisionTree::calculateGiniImpurity(
        const torch::Tensor& y,
        const torch::Tensor& sample_weights)
    {
        if (y.size(0) == 0 || sample_weights.size(0) == 0) {
            return 0.0;
        }

        if (y.size(0) != sample_weights.size(0)) {
            throw std::runtime_error("y and sample_weights must have same size");
        }

        torch::Tensor class_weights = torch::zeros({ n_classes });

        // Calculate weighted class counts
        for (int i = 0; i < y.size(0); i++) {
            int class_idx = y[i].item<int>();

            if (class_idx < 0 || class_idx >= n_classes) {
                throw std::runtime_error("Invalid class index: " + std::to_string(class_idx));
            }

            class_weights[class_idx] += sample_weights[i].item<float>();
        }

        // Normalize
        double total_weight = class_weights.sum().item<double>();
        if (total_weight == 0) return 0.0;

        class_weights /= total_weight;

        // Calculate Gini impurity: 1 - sum(p_i^2)
        double gini = 1.0;
        for (int i = 0; i < n_classes; i++) {
            double p = class_weights[i].item<double>();
            gini -= p * p;
        }

        return gini;
    }

    torch::Tensor DecisionTree::predict(torch::Tensor& X)
    {
        if (!fitted) {
            throw std::runtime_error(CLASSIFIER_NOT_FITTED);
        }

        int n_samples = X.size(1);
        torch::Tensor predictions = torch::zeros({ n_samples }, torch::kInt32);

        for (int i = 0; i < n_samples; i++) {
            auto sample = X.index({ torch::indexing::Slice(), i }).ravel();
            predictions[i] = predictSample(sample);
        }

        return predictions;
    }
    void dumpTensor(const torch::Tensor& tensor, const std::string& name)
    {
        std::cout << name << ": " << std::endl;
        for (int i = 0; i < tensor.size(0); i++) {
            std::cout << "[";
            for (int j = 0; j < tensor.size(1); j++) {
                std::cout << tensor[i][j].item<int>() << " ";
            }
            std::cout << "]" << std::endl;
        }
        std::cout << std::endl;
    }
    void dumpVector(const std::vector<std::vector<int>>& vec, const std::string& name)
    {
        std::cout << name << ": " << std::endl;
        for (const auto& row : vec) {
            std::cout << "[";
            for (const auto& val : row) {
                std::cout << val << " ";
            }
            std::cout << "] " << std::endl;
        }
        std::cout << std::endl;
    }

    std::vector<int> DecisionTree::predict(std::vector<std::vector<int>>& X)
    {
        // Convert to tensor
        long n = X.size();
        long m = X.at(0).size();
        torch::Tensor X_tensor = platform::TensorUtils::to_matrix(X);
        auto predictions = predict(X_tensor);
        std::vector<int> result = platform::TensorUtils::to_vector<int>(predictions);
        return result;
    }

    torch::Tensor DecisionTree::predict_proba(torch::Tensor& X)
    {
        if (!fitted) {
            throw std::runtime_error(CLASSIFIER_NOT_FITTED);
        }

        int n_samples = X.size(1);
        torch::Tensor probabilities = torch::zeros({ n_samples, n_classes });

        for (int i = 0; i < n_samples; i++) {
            auto sample = X.index({ torch::indexing::Slice(), i }).ravel();
            probabilities[i] = predictProbaSample(sample);
        }

        return probabilities;
    }

    std::vector<std::vector<double>> DecisionTree::predict_proba(std::vector<std::vector<int>>& X)
    {
        auto n_samples = X.at(0).size();
        // Convert to tensor
        torch::Tensor X_tensor = platform::TensorUtils::to_matrix(X);
        auto proba_tensor = predict_proba(X_tensor);
        std::vector<std::vector<double>> result(n_samples, std::vector<double>(n_classes, 0.0));

        for (int i = 0; i < n_samples; i++) {
            for (int j = 0; j < n_classes; j++) {
                result[i][j] = proba_tensor[i][j].item<double>();
            }
        }

        return result;
    }

    int DecisionTree::predictSample(const torch::Tensor& x) const
    {
        if (!fitted) {
            throw std::runtime_error(CLASSIFIER_NOT_FITTED);
        }

        if (x.size(0) != n) { // n should be the number of features
            throw std::runtime_error("Input sample has wrong number of features");
        }

        const TreeNode* leaf = traverseTree(x, root.get());
        return leaf->predicted_class;
    }
    torch::Tensor DecisionTree::predictProbaSample(const torch::Tensor& x) const
    {
        const TreeNode* leaf = traverseTree(x, root.get());
        return leaf->class_probabilities.clone();
    }

    const TreeNode* DecisionTree::traverseTree(const torch::Tensor& x, const TreeNode* node) const
    {
        if (!node) {
            throw std::runtime_error("Null node encountered during tree traversal");
        }

        if (node->is_leaf) {
            return node;
        }

        if (node->split_feature < 0 || node->split_feature >= x.size(0)) {
            throw std::runtime_error("Invalid split_feature index: " + std::to_string(node->split_feature));
        }

        int feature_value = x[node->split_feature].item<int>();

        if (feature_value == node->split_value) {
            if (!node->left) {
                throw std::runtime_error("Missing left child in tree");
            }
            return traverseTree(x, node->left.get());
        } else {
            if (!node->right) {
                throw std::runtime_error("Missing right child in tree");
            }
            return traverseTree(x, node->right.get());
        }
    }

    std::vector<std::string> DecisionTree::graph(const std::string& title) const
    {
        std::vector<std::string> lines;
        lines.push_back("digraph DecisionTree {");
        lines.push_back(" rankdir=TB;");
        lines.push_back(" node [shape=box, style=\"filled, rounded\", fontname=\"helvetica\"];");
        lines.push_back(" edge [fontname=\"helvetica\"];");

        if (!title.empty()) {
            lines.push_back(" label=\"" + title + "\";");
            lines.push_back(" labelloc=t;");
        }

        if (root) {
            int node_id = 0;
            treeToGraph(root.get(), lines, node_id);
        }

        lines.push_back("}");
        return lines;
    }

    void DecisionTree::treeToGraph(
        const TreeNode* node,
        std::vector<std::string>& lines,
        int& node_id,
        int parent_id,
        const std::string& edge_label) const
    {
        int current_id = node_id++;
        std::stringstream ss;

        if (node->is_leaf) {
            // Leaf node
            ss << " node" << current_id << " [label=\"Class: " << node->predicted_class;
            ss << "\\nProb: " << std::fixed << std::setprecision(3)
                << node->class_probabilities[node->predicted_class].item<float>();
            ss << "\", fillcolor=\"lightblue\"];";
            lines.push_back(ss.str());
        } else {
            // Internal node
            ss << " node" << current_id << " [label=\"" << features[node->split_feature];
            ss << " = " << node->split_value << "?\", fillcolor=\"lightgreen\"];";
            lines.push_back(ss.str());
        }

        // Add edge from parent
        if (parent_id >= 0) {
            ss.str("");
            ss << " node" << parent_id << " -> node" << current_id;
            if (!edge_label.empty()) {
                ss << " [label=\"" << edge_label << "\"];";
            } else {
                ss << ";";
            }
            lines.push_back(ss.str());
        }

        // Recurse on children
        if (!node->is_leaf) {
            if (node->left) {
                treeToGraph(node->left.get(), lines, node_id, current_id, "Yes");
            }
            if (node->right) {
                treeToGraph(node->right.get(), lines, node_id, current_id, "No");
            }
        }
    }

} // namespace bayesnet
src/experimental_clfs/DecisionTree.h (new file, 129 lines)
@@ -0,0 +1,129 @@
// ***************************************************************
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
// SPDX-FileType: SOURCE
// SPDX-License-Identifier: MIT
// ***************************************************************

#ifndef DECISION_TREE_H
#define DECISION_TREE_H

#include <memory>
#include <vector>
#include <map>
#include <torch/torch.h>
#include "bayesnet/classifiers/Classifier.h"

namespace bayesnet {

    // Forward declaration
    struct TreeNode;

    class DecisionTree : public Classifier {
    public:
        explicit DecisionTree(int max_depth = 3, int min_samples_split = 2, int min_samples_leaf = 1);
        virtual ~DecisionTree() = default;

        // Override graph method to show tree structure
        std::vector<std::string> graph(const std::string& title = "") const override;

        // Setters for hyperparameters
        void setMaxDepth(int depth) { max_depth = depth; checkValues(); }
        void setMinSamplesSplit(int samples) { min_samples_split = samples; checkValues(); }
        void setMinSamplesLeaf(int samples) { min_samples_leaf = samples; checkValues(); }

        // Override setHyperparameters
        void setHyperparameters(const nlohmann::json& hyperparameters) override;

        torch::Tensor predict(torch::Tensor& X) override;
        std::vector<int> predict(std::vector<std::vector<int>>& X) override;
        torch::Tensor predict_proba(torch::Tensor& X) override;
        std::vector<std::vector<double>> predict_proba(std::vector<std::vector<int>>& X);

    protected:
        void buildModel(const torch::Tensor& weights) override;
        void trainModel(const torch::Tensor& weights, const Smoothing_t smoothing) override
        {
            // Decision trees do not require training in the traditional sense
            // as they are built from the data directly.
            // This method can be used to set weights or other parameters if needed.
        }
    private:
        void checkValues();
        bool validateTensors(const torch::Tensor& X, const torch::Tensor& y, const torch::Tensor& sample_weights) const;
        // Tree hyperparameters
        int max_depth;
        int min_samples_split;
        int min_samples_leaf;
        int n_classes; // Number of classes in the target variable

        // Root of the decision tree
        std::unique_ptr<TreeNode> root;

        // Build tree recursively
        std::unique_ptr<TreeNode> buildTree(
            const torch::Tensor& X,
            const torch::Tensor& y,
            const torch::Tensor& sample_weights,
            int current_depth
        );

        // Find best split for a node
        struct SplitInfo {
            int feature_index;
            int split_value;
            double impurity_decrease;
            torch::Tensor left_mask;
            torch::Tensor right_mask;
        };

        SplitInfo findBestSplit(
            const torch::Tensor& X,
            const torch::Tensor& y,
            const torch::Tensor& sample_weights
        );

        // Calculate weighted Gini impurity for multi-class
        double calculateGiniImpurity(
            const torch::Tensor& y,
            const torch::Tensor& sample_weights
        );

        // Make predictions for a single sample
        int predictSample(const torch::Tensor& x) const;

        // Make probabilistic predictions for a single sample
        torch::Tensor predictProbaSample(const torch::Tensor& x) const;

        // Traverse tree to find leaf node
        const TreeNode* traverseTree(const torch::Tensor& x, const TreeNode* node) const;

        // Convert tree to graph representation
        void treeToGraph(
            const TreeNode* node,
            std::vector<std::string>& lines,
            int& node_id,
            int parent_id = -1,
            const std::string& edge_label = ""
        ) const;
    };

    // Tree node structure
    struct TreeNode {
        bool is_leaf;

        // For internal nodes
        int split_feature;
        int split_value;
        std::unique_ptr<TreeNode> left;
        std::unique_ptr<TreeNode> right;

        // For leaf nodes
        int predicted_class;
        torch::Tensor class_probabilities; // Probability for each class

        TreeNode() : is_leaf(false), split_feature(-1), split_value(-1), predicted_class(-1) {}
    };

} // namespace bayesnet

#endif // DECISION_TREE_H
src/experimental_clfs/README.md (new file, 142 lines)
@@ -0,0 +1,142 @@
# AdaBoost and DecisionTree Classifier Implementation

This implementation provides both a Decision Tree classifier and a multi-class AdaBoost classifier based on the SAMME (Stagewise Additive Modeling using a Multi-class Exponential loss) algorithm described in the paper "Multi-class AdaBoost" by Zhu et al. Implemented in C++ using <https://claude.ai>.

## Components

### 1. DecisionTree Classifier

A classic decision tree implementation that:

- Supports multi-class classification
- Handles weighted samples (essential for boosting)
- Uses Gini impurity as the splitting criterion
- Works with discrete/categorical features
- Provides both class predictions and probability estimates

#### Key Features

- **Max Depth Control**: Limit tree depth to create weak learners
- **Minimum Samples**: Control minimum samples for splitting and leaf nodes
- **Weighted Training**: Properly handles sample weights for boosting
- **Visualization**: Generates DOT format graphs of the tree structure

#### Hyperparameters

- `max_depth`: Maximum depth of the tree (default: 3)
- `min_samples_split`: Minimum samples required to split a node (default: 2)
- `min_samples_leaf`: Minimum samples required in a leaf node (default: 1)
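
A minimal construction sketch (the values mirror the defaults above; the setters are the ones declared in `DecisionTree.h`, and each one re-validates its argument):

```cpp
// Sketch only: build a tree with explicit hyperparameters, then adjust two of them.
DecisionTree tree(3 /* max_depth */, 2 /* min_samples_split */, 1 /* min_samples_leaf */);
tree.setMaxDepth(4);        // each setter calls checkValues() and rejects non-positive values
tree.setMinSamplesLeaf(2);
```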

### 2. AdaBoost Classifier

A multi-class AdaBoost implementation using DecisionTree as base estimators:

- **SAMME Algorithm**: Implements the multi-class extension of AdaBoost
- **Automatic Stumps**: Uses decision stumps (max_depth=1) by default
- **Early Stopping**: Stops if base classifier performs worse than random
- **Ensemble Visualization**: Shows the weighted combination of base estimators

#### Key Features

- **Multi-class Support**: Natural extension to K classes
- **Base Estimator Control**: Configure depth of base decision trees
- **Training Monitoring**: Track training errors and estimator weights
- **Probability Estimates**: Provides class probability predictions

#### Hyperparameters

- `n_estimators`: Number of base estimators to train (default: 50)
- `base_max_depth`: Maximum depth for base decision trees (default: 1)
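
Both values can also be set after construction through the JSON interface; a minimal sketch (the keys are exactly the ones `AdaBoost::setHyperparameters` consumes before delegating to `Ensemble::setHyperparameters`):

```cpp
#include <nlohmann/json.hpp>

// Sketch only: configure an AdaBoost ensemble of depth-2 trees.
nlohmann::json hyperparameters;
hyperparameters["n_estimators"] = 100;
hyperparameters["base_max_depth"] = 2;

AdaBoost ada;
ada.setHyperparameters(hyperparameters);
```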

## Algorithm Details

The SAMME algorithm differs from binary AdaBoost in the calculation of the estimator weight (alpha):

```
α = log((1 - err) / err) + log(K - 1)
```

where `K` is the number of classes. This formula ensures that:

- When K = 2, it reduces to standard AdaBoost
- For K > 2, base classifiers only need to be better than random guessing (1/K) rather than 50%
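
As an illustration of the formula only (the helper name and the epsilon guard below are not part of the committed code):

```cpp
#include <algorithm>
#include <cmath>

// Hypothetical helper: weight of one base estimator under SAMME.
// err is its weighted training error, K the number of classes.
double samme_alpha(double err, int K) {
    const double eps = 1e-10;                  // avoid log(0) when the tree is perfect
    err = std::clamp(err, eps, 1.0 - eps);
    // Positive only while err < 1 - 1/K, i.e. while the tree beats random guessing.
    return std::log((1.0 - err) / err) + std::log(static_cast<double>(K) - 1.0);
}
```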

## Usage Example

```cpp
// Create AdaBoost with decision stumps
AdaBoost ada(100, 1);  // 100 estimators, max_depth=1

// Train
ada.fit(X_train, y_train, features, className, states, Smoothing_t::NONE);

// Predict
auto predictions = ada.predict(X_test);
auto probabilities = ada.predict_proba(X_test);

// Evaluate
float accuracy = ada.score(X_test, y_test);

// Get ensemble information
auto weights = ada.getEstimatorWeights();
auto errors = ada.getTrainingErrors();
```

## Implementation Structure

```
AdaBoost (inherits from Ensemble)
└── Uses multiple DecisionTree instances as base estimators
    └── DecisionTree (inherits from Classifier)
        └── Implements weighted Gini impurity splitting
```

## Visualization

Both classifiers support graph visualization:

- **DecisionTree**: Shows the tree structure with split conditions
- **AdaBoost**: Shows the ensemble of weighted base estimators

Generate visualizations using:

```cpp
auto graph = classifier.graph("Title");
```
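
The returned lines are plain Graphviz DOT text; one way to render them (the file name is arbitrary, and `graph` refers to the variable from the call above):

```cpp
#include <fstream>

// Sketch only: dump the DOT lines to disk, then render with
//   dot -Tpng tree.dot -o tree.png
std::ofstream out("tree.dot");
for (const auto& line : graph) {
    out << line << "\n";
}
```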

## Data Format

Both classifiers expect discrete/categorical data:

- **Features**: Integer values representing categories (stored in `torch::Tensor` or `std::vector<std::vector<int>>`)
- **Labels**: Integer values representing class indices (0, 1, ..., K-1)
- **States**: Map defining possible values for each feature and the class variable
- **Sample Weights**: Optional weights for each training sample (important for boosting)

Example data setup:

```cpp
// Features matrix (n_features x n_samples)
torch::Tensor X = torch::tensor({{0, 1, 2}, {1, 0, 1}});  // 2 features, 3 samples

// Labels vector
torch::Tensor y = torch::tensor({0, 1, 0});  // 3 samples

// States definition
std::map<std::string, std::vector<int>> states;
states["feature1"] = {0, 1, 2};  // Feature 1 can take values 0, 1, or 2
states["feature2"] = {0, 1};     // Feature 2 can take values 0 or 1
states["class"] = {0, 1};        // Binary classification
```

## Notes

- The implementation handles discrete/categorical features as indicated by the int-based data structures
- Sample weights are properly propagated through the tree building process
- The DecisionTree implementation uses equality testing for splits (suitable for categorical data)
- Both classifiers support the standard fit/predict interface from the base framework

## References

- Zhu, J., Zou, H., Rosset, S., & Hastie, T. (2009). Multi-class AdaBoost. Statistics and Its Interface, 2(3), 349-360.
- Breiman, L., Friedman, J., Olshen, R., & Stone, C. (1984). Classification and Regression Trees. Wadsworth, Belmont, CA.
@@ -45,6 +45,19 @@ namespace platform {
 
             return data;
         }
+        static torch::Tensor to_matrix(const std::vector<std::vector<int>>& data)
+        {
+            if (data.empty()) return torch::empty({ 0, 0 }, torch::kInt64);
+            size_t rows = data.size();
+            size_t cols = data[0].size();
+            torch::Tensor tensor = torch::empty({ static_cast<long>(rows), static_cast<long>(cols) }, torch::kInt64);
+            for (size_t i = 0; i < rows; ++i) {
+                for (size_t j = 0; j < cols; ++j) {
+                    tensor.index_put_({ static_cast<long>(i), static_cast<long>(j) }, data[i][j]);
+                }
+            }
+            return tensor;
+        }
     };
 }
 
@@ -23,10 +23,11 @@
 #include <pyclassifiers/ODTE.h>
 #include <pyclassifiers/SVC.h>
 #include <pyclassifiers/XGBoost.h>
-#include <pyclassifiers/AdaBoost.h>
+#include <pyclassifiers/AdaBoostPy.h>
 #include <pyclassifiers/RandomForest.h>
 #include "../experimental_clfs/XA1DE.h"
-#include "../experimental_clfs/AdaBoost.h"
+// #include "../experimental_clfs/AdaBoost.h"
+#include "../experimental_clfs/DecisionTree.h"
 
 namespace platform {
     class Models {
@@ -35,10 +35,12 @@ namespace platform {
             [](void) -> bayesnet::BaseClassifier* { return new pywrap::RandomForest();});
         static Registrar registrarXGB("XGBoost",
             [](void) -> bayesnet::BaseClassifier* { return new pywrap::XGBoost();});
-        static Registrar registrarAda("AdaBoostPy",
-            [](void) -> bayesnet::BaseClassifier* { return new pywrap::AdaBoost();});
-        // static Registrar registrarAda2("AdaBoost",
-        //     [](void) -> bayesnet::BaseClassifier* { return new platform::AdaBoost();});
+        static Registrar registrarAdaPy("AdaBoostPy",
+            [](void) -> bayesnet::BaseClassifier* { return new pywrap::AdaBoostPy();});
+        // static Registrar registrarAda("AdaBoost",
+        //     [](void) -> bayesnet::BaseClassifier* { return new bayesnet::AdaBoost();});
+        static Registrar registrarDT("DecisionTree",
+            [](void) -> bayesnet::BaseClassifier* { return new bayesnet::DecisionTree();});
         static Registrar registrarXSPODE("XSPODE",
             [](void) -> bayesnet::BaseClassifier* { return new bayesnet::XSpode(0);});
         static Registrar registrarXSP2DE("XSP2DE",
@@ -12,11 +12,11 @@ if(ENABLE_TESTING)
         ${Bayesnet_INCLUDE_DIRS}
     )
     set(TEST_SOURCES_PLATFORM
-        TestUtils.cpp TestPlatform.cpp TestResult.cpp TestScores.cpp
+        TestUtils.cpp TestPlatform.cpp TestResult.cpp TestScores.cpp TestDecisionTree.cpp
         ${Platform_SOURCE_DIR}/src/common/Datasets.cpp ${Platform_SOURCE_DIR}/src/common/Dataset.cpp ${Platform_SOURCE_DIR}/src/common/Discretization.cpp
-        ${Platform_SOURCE_DIR}/src/main/Scores.cpp
+        ${Platform_SOURCE_DIR}/src/main/Scores.cpp ${Platform_SOURCE_DIR}/src/experimental_clfs/DecisionTree.cpp
     )
     add_executable(${TEST_PLATFORM} ${TEST_SOURCES_PLATFORM})
-    target_link_libraries(${TEST_PLATFORM} PUBLIC "${TORCH_LIBRARIES}" mdlp Catch2::Catch2WithMain BayesNet)
+    target_link_libraries(${TEST_PLATFORM} PUBLIC "${TORCH_LIBRARIES}" fimdlp Catch2::Catch2WithMain bayesnet)
     add_test(NAME ${TEST_PLATFORM} COMMAND ${TEST_PLATFORM})
 endif(ENABLE_TESTING)
303
tests/TestDecisionTree.cpp
Normal file
303
tests/TestDecisionTree.cpp
Normal file
@@ -0,0 +1,303 @@
|
|||||||
|
// ***************************************************************
|
||||||
|
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
|
||||||
|
// SPDX-FileType: SOURCE
|
||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
// ***************************************************************
|
||||||
|
|
||||||
|
#include <catch2/catch_test_macros.hpp>
|
||||||
|
#include <catch2/catch_approx.hpp>
|
||||||
|
#include <catch2/matchers/catch_matchers_string.hpp>
|
||||||
|
#include <catch2/matchers/catch_matchers_vector.hpp>
|
||||||
|
#include <torch/torch.h>
|
||||||
|
#include <memory>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include "experimental_clfs/DecisionTree.h"
|
||||||
|
#include "TestUtils.h"
|
||||||
|
|
||||||
|
using namespace bayesnet;
|
||||||
|
using namespace Catch::Matchers;
|
||||||
|
|
||||||
|
TEST_CASE("DecisionTree Construction", "[DecisionTree]")
|
||||||
|
{
|
||||||
|
SECTION("Default constructor")
|
||||||
|
{
|
||||||
|
REQUIRE_NOTHROW(DecisionTree());
|
||||||
|
}
|
||||||
|
|
||||||
|
SECTION("Constructor with parameters")
|
||||||
|
{
|
||||||
|
REQUIRE_NOTHROW(DecisionTree(5, 10, 3));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("DecisionTree Hyperparameter Setting", "[DecisionTree]")
|
||||||
|
{
|
||||||
|
DecisionTree dt;
|
||||||
|
|
||||||
|
SECTION("Set individual hyperparameters")
|
||||||
|
{
|
||||||
|
REQUIRE_NOTHROW(dt.setMaxDepth(10));
|
||||||
|
REQUIRE_NOTHROW(dt.setMinSamplesSplit(5));
|
||||||
|
REQUIRE_NOTHROW(dt.setMinSamplesLeaf(2));
|
||||||
|
}
|
||||||
|
|
||||||
|
SECTION("Set hyperparameters via JSON")
|
||||||
|
{
|
||||||
|
nlohmann::json params;
|
||||||
|
params["max_depth"] = 7;
|
||||||
|
params["min_samples_split"] = 4;
|
||||||
|
params["min_samples_leaf"] = 2;
|
||||||
|
|
||||||
|
REQUIRE_NOTHROW(dt.setHyperparameters(params));
|
||||||
|
}
|
||||||
|
|
||||||
|
SECTION("Invalid hyperparameters should throw")
|
||||||
|
{
|
||||||
|
nlohmann::json params;
|
||||||
|
|
||||||
|
// Negative max_depth
|
||||||
|
params["max_depth"] = -1;
|
||||||
|
REQUIRE_THROWS_AS(dt.setHyperparameters(params), std::invalid_argument);
|
||||||
|
|
||||||
|
// Zero min_samples_split
|
||||||
|
params["max_depth"] = 5;
|
||||||
|
params["min_samples_split"] = 0;
|
||||||
|
REQUIRE_THROWS_AS(dt.setHyperparameters(params), std::invalid_argument);
|
||||||
|
|
||||||
|
// Negative min_samples_leaf
|
||||||
|
params["min_samples_split"] = 2;
|
||||||
|
params["min_samples_leaf"] = -5;
|
||||||
|
REQUIRE_THROWS_AS(dt.setHyperparameters(params), std::invalid_argument);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("DecisionTree Basic Functionality", "[DecisionTree]")
|
||||||
|
{
|
||||||
|
// Create a simple dataset
|
||||||
|
int n_samples = 20;
|
||||||
|
int n_features = 2;
|
||||||
|
|
||||||
|
std::vector<std::vector<int>> X(n_features, std::vector<int>(n_samples));
|
||||||
|
std::vector<int> y(n_samples);
|
||||||
|
|
||||||
|
// Simple pattern: class depends on first feature
|
||||||
|
for (int i = 0; i < n_samples; i++) {
|
||||||
|
X[0][i] = i < 10 ? 0 : 1;
|
||||||
|
X[1][i] = i % 2;
|
||||||
|
y[i] = X[0][i]; // Class equals first feature
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> features = { "f1", "f2" };
|
||||||
|
std::string className = "class";
|
||||||
|
std::map<std::string, std::vector<int>> states;
|
||||||
|
states["f1"] = { 0, 1 };
|
||||||
|
states["f2"] = { 0, 1 };
|
||||||
|
states["class"] = { 0, 1 };
|
||||||
|
|
||||||
|
SECTION("Training with vector interface")
|
||||||
|
{
|
||||||
|
DecisionTree dt(3, 2, 1);
|
||||||
|
REQUIRE_NOTHROW(dt.fit(X, y, features, className, states, Smoothing_t::NONE));
|
||||||
|
|
||||||
|
auto predictions = dt.predict(X);
|
||||||
|
REQUIRE(predictions.size() == static_cast<size_t>(n_samples));
|
||||||
|
|
||||||
|
// Should achieve perfect accuracy on this simple dataset
|
||||||
|
int correct = 0;
|
||||||
|
for (size_t i = 0; i < predictions.size(); i++) {
|
||||||
|
if (predictions[i] == y[i]) correct++;
|
||||||
|
}
|
||||||
|
REQUIRE(correct == n_samples);
|
||||||
|
}
|
||||||
|
|
||||||
|
SECTION("Prediction before fitting")
|
||||||
|
{
|
||||||
|
DecisionTree dt;
|
||||||
|
REQUIRE_THROWS_WITH(dt.predict(X),
|
||||||
|
ContainsSubstring("Classifier has not been fitted"));
|
||||||
|
}
|
||||||
|
|
||||||
|
SECTION("Probability predictions")
|
||||||
|
{
|
||||||
|
DecisionTree dt(3, 2, 1);
|
||||||
|
dt.fit(X, y, features, className, states, Smoothing_t::NONE);
|
||||||
|
|
||||||
|
auto proba = dt.predict_proba(X);
|
||||||
|
REQUIRE(proba.size() == static_cast<size_t>(n_samples));
|
||||||
|
REQUIRE(proba[0].size() == 2); // Two classes
|
||||||
|
|
||||||
|
// Check probabilities sum to 1 and probabilities are valid
|
||||||
|
auto predictions = dt.predict(X);
|
||||||
|
for (size_t i = 0; i < proba.size(); i++) {
|
||||||
|
auto p = proba[i];
|
||||||
|
auto pred = predictions[i];
|
||||||
|
REQUIRE(p.size() == 2);
|
||||||
|
REQUIRE(p[0] >= 0.0);
|
||||||
|
REQUIRE(p[1] >= 0.0);
|
||||||
|
double sum = p[0] + p[1];
|
||||||
|
//Check that prodict_proba matches the expected predict value
|
||||||
|
REQUIRE(pred == (p[0] > p[1] ? 0 : 1));
|
||||||
|
REQUIRE(sum == Catch::Approx(1.0).epsilon(1e-6));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("DecisionTree on Iris Dataset", "[DecisionTree][iris]")
|
||||||
|
{
|
||||||
|
auto raw = RawDatasets("iris", true);
|
||||||
|
|
||||||
|
SECTION("Training with dataset format")
|
||||||
|
{
|
||||||
|
DecisionTree dt(5, 2, 1);
|
||||||
|
|
||||||
|
INFO("Dataset shape: " << raw.dataset.sizes());
|
||||||
|
INFO("Features: " << raw.featurest.size());
|
||||||
|
INFO("Samples: " << raw.nSamples);
|
||||||
|
|
||||||
|
// DecisionTree expects dataset in format: features x samples, with labels as last row
|
||||||
|
REQUIRE_NOTHROW(dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE));
|
||||||
|
|
||||||
|
// Test prediction
|
||||||
|
auto predictions = dt.predict(raw.Xt);
|
||||||
|
REQUIRE(predictions.size(0) == raw.yt.size(0));
|
||||||
|
|
||||||
|
// Calculate accuracy
|
||||||
|
auto correct = torch::sum(predictions == raw.yt).item<int>();
|
||||||
|
double accuracy = static_cast<double>(correct) / raw.yt.size(0);
|
||||||
|
REQUIRE(accuracy > 0.97); // Reasonable accuracy for Iris
|
||||||
|
}
|
||||||
|
|
||||||
|
SECTION("Training with vector interface")
|
||||||
|
{
|
||||||
|
DecisionTree dt(5, 2, 1);
|
||||||
|
|
||||||
|
REQUIRE_NOTHROW(dt.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv, Smoothing_t::NONE));
|
||||||
|
|
||||||
|
// std::cout << "Tree structure:\n";
|
||||||
|
// auto graph_lines = dt.graph("Iris Decision Tree");
|
||||||
|
// for (const auto& line : graph_lines) {
|
||||||
|
// std::cout << line << "\n";
|
||||||
|
// }
|
||||||
|
auto predictions = dt.predict(raw.Xv);
|
||||||
|
REQUIRE(predictions.size() == raw.yv.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
SECTION("Different tree depths")
|
||||||
|
{
|
||||||
|
std::vector<int> depths = { 1, 3, 5 };
|
||||||
|
|
||||||
|
for (int depth : depths) {
|
||||||
|
DecisionTree dt(depth, 2, 1);
|
||||||
|
dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE);
|
||||||
|
|
||||||
|
auto predictions = dt.predict(raw.Xt);
|
||||||
|
REQUIRE(predictions.size(0) == raw.yt.size(0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||

TEST_CASE("DecisionTree Edge Cases", "[DecisionTree]")
{
    auto raw = RawDatasets("iris", true);

    SECTION("Very shallow tree")
    {
        DecisionTree dt(1, 2, 1); // depth = 1
        dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE);

        auto predictions = dt.predict(raw.Xt);
        REQUIRE(predictions.size(0) == raw.yt.size(0));

        // With depth 1, should have at most 2 unique predictions
        auto unique_vals = at::_unique(predictions);
        REQUIRE(std::get<0>(unique_vals).size(0) <= 2);
    }

    SECTION("High min_samples_split")
    {
        DecisionTree dt(10, 50, 1);
        dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE);

        auto predictions = dt.predict(raw.Xt);
        REQUIRE(predictions.size(0) == raw.yt.size(0));
    }
}

TEST_CASE("DecisionTree Graph Visualization", "[DecisionTree]")
{
    // Simple dataset
    std::vector<std::vector<int>> X = { {0,0,0,1}, {0,1,1,1} }; // XOR pattern
    std::vector<int> y = { 0, 1, 1, 0 }; // XOR pattern
    std::vector<std::string> features = { "x1", "x2" };
    std::string className = "xor";
    std::map<std::string, std::vector<int>> states;
    states["x1"] = { 0, 1 };
    states["x2"] = { 0, 1 };
    states["xor"] = { 0, 1 };

    SECTION("Graph generation")
    {
        DecisionTree dt(2, 1, 1);
        dt.fit(X, y, features, className, states, Smoothing_t::NONE);

        auto graph_lines = dt.graph();

        REQUIRE(graph_lines.size() > 2);
        REQUIRE(graph_lines.front() == "digraph DecisionTree {");
        REQUIRE(graph_lines.back() == "}");

        // Should contain node definitions
        bool has_nodes = false;
        for (const auto& line : graph_lines) {
            if (line.find("node") != std::string::npos) {
                has_nodes = true;
                break;
            }
        }
        REQUIRE(has_nodes);
    }

    SECTION("Graph with title")
    {
        DecisionTree dt(2, 1, 1);
        dt.fit(X, y, features, className, states, Smoothing_t::NONE);

        auto graph_lines = dt.graph("XOR Tree");

        bool has_title = false;
        for (const auto& line : graph_lines) {
            if (line.find("label=\"XOR Tree\"") != std::string::npos) {
                has_title = true;
                break;
            }
        }
        REQUIRE(has_title);
    }
}
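The graph tests only assert that graph() returns Graphviz DOT text wrapped in "digraph DecisionTree { ... }". As a usage note (not part of the test file; the file name and rendering command are illustrative), the returned lines could be written to disk and rendered with the standard Graphviz dot tool:

    #include <fstream>
    #include <string>
    #include <vector>

    // Illustrative helper: dump the DOT lines returned by DecisionTree::graph().
    void write_dot(const std::vector<std::string>& graph_lines, const std::string& path)
    {
        std::ofstream out(path);
        for (const auto& line : graph_lines)
            out << line << "\n";
        // then, for example: dot -Tpng tree.dot -o tree.png
    }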

TEST_CASE("DecisionTree with Weights", "[DecisionTree]")
{
    auto raw = RawDatasets("iris", true);

    SECTION("Uniform weights")
    {
        DecisionTree dt(5, 2, 1);
        dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, raw.weights, Smoothing_t::NONE);

        auto predictions = dt.predict(raw.Xt);
        REQUIRE(predictions.size(0) == raw.yt.size(0));
    }

    SECTION("Non-uniform weights")
    {
        auto weights = torch::ones({ raw.nSamples });
        weights.index({ torch::indexing::Slice(0, 50) }) *= 2.0; // Emphasize first class
        weights = weights / weights.sum();

        DecisionTree dt(5, 2, 1);
        dt.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, weights, Smoothing_t::NONE);

        auto predictions = dt.predict(raw.Xt);
        REQUIRE(predictions.size(0) == raw.yt.size(0));
    }
}

@@ -7,7 +7,7 @@
 #include <string>
 #include "TestUtils.h"
 #include "folding.hpp"
-#include <ArffFiles.hpp>
+#include <ArffFiles/ArffFiles.hpp>
 #include <bayesnet/classifiers/TAN.h>
 #include "config_platform.h"
 
@@ -20,17 +20,17 @@ TEST_CASE("Test Platform version", "[Platform]")
 TEST_CASE("Test Folding library version", "[Folding]")
 {
     std::string version = folding::KFold(5, 100).version();
-    REQUIRE(version == "1.1.0");
+    REQUIRE(version == "1.1.1");
 }
 TEST_CASE("Test BayesNet version", "[BayesNet]")
 {
     std::string version = bayesnet::TAN().getVersion();
-    REQUIRE(version == "1.0.6");
+    REQUIRE(version == "1.1.2");
 }
 TEST_CASE("Test mdlp version", "[mdlp]")
 {
     std::string version = mdlp::CPPFImdlp::version();
-    REQUIRE(version == "2.0.0");
+    REQUIRE(version == "2.0.1");
 }
 TEST_CASE("Test Arff version", "[Arff]")
 {
@@ -14,38 +14,40 @@
 using json = nlohmann::ordered_json;
 auto epsilon = 1e-4;
 
-void make_test_bin(int TP, int TN, int FP, int FN, std::vector<int>& y_test, std::vector<int>& y_pred)
+void make_test_bin(int TP, int TN, int FP, int FN, std::vector<int>& y_test, torch::Tensor& y_pred)
 {
-    // TP
+    std::vector<std::array<double, 2>> probs;
+    // TP: true positive (label 1, predicted 1)
     for (int i = 0; i < TP; i++) {
         y_test.push_back(1);
-        y_pred.push_back(1);
+        probs.push_back({ 0.0, 1.0 }); // P(class 0)=0, P(class 1)=1
     }
-    // TN
+    // TN: true negative (label 0, predicted 0)
     for (int i = 0; i < TN; i++) {
         y_test.push_back(0);
-        y_pred.push_back(0);
+        probs.push_back({ 1.0, 0.0 }); // P(class 0)=1, P(class 1)=0
     }
-    // FP
+    // FP: false positive (label 0, predicted 1)
     for (int i = 0; i < FP; i++) {
         y_test.push_back(0);
-        y_pred.push_back(1);
+        probs.push_back({ 0.0, 1.0 }); // P(class 0)=0, P(class 1)=1
     }
-    // FN
+    // FN: false negative (label 1, predicted 0)
     for (int i = 0; i < FN; i++) {
         y_test.push_back(1);
-        y_pred.push_back(0);
+        probs.push_back({ 1.0, 0.0 }); // P(class 0)=1, P(class 1)=0
     }
+    // Convert to torch::Tensor of double, shape [N,2]
+    y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 2 }, torch::kFloat64).clone();
 }
 
 TEST_CASE("Scores binary", "[Scores]")
 {
     std::vector<int> y_test;
-    std::vector<int> y_pred;
+    torch::Tensor y_pred;
     make_test_bin(197, 210, 52, 41, y_test, y_pred);
     auto y_test_tensor = torch::tensor(y_test, torch::kInt32);
-    auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
-    platform::Scores scores(y_test_tensor, y_pred_tensor, 2);
+    platform::Scores scores(y_test_tensor, y_pred, 2);
     REQUIRE(scores.accuracy() == Catch::Approx(0.814).epsilon(epsilon));
     REQUIRE(scores.f1_score(0) == Catch::Approx(0.818713));
     REQUIRE(scores.f1_score(1) == Catch::Approx(0.809035));
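As a sanity check on the hard-coded expectations in "Scores binary" (an illustrative aside, not part of the test file): with TP=197, TN=210, FP=52, FN=41 the asserted values follow directly from the counts, assuming the usual definitions of accuracy and per-class F1.

    #include <cstdio>

    int main()
    {
        double TP = 197, TN = 210, FP = 52, FN = 41;
        double accuracy = (TP + TN) / (TP + TN + FP + FN);     // 407 / 500 = 0.814
        double prec1 = TP / (TP + FP), rec1 = TP / (TP + FN);
        double f1_class1 = 2 * prec1 * rec1 / (prec1 + rec1);  // ~0.809035
        double prec0 = TN / (TN + FN), rec0 = TN / (TN + FP);
        double f1_class0 = 2 * prec0 * rec0 / (prec0 + rec0);  // ~0.818713
        std::printf("%.6f %.6f %.6f\n", accuracy, f1_class0, f1_class1);
        return 0;
    }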
@@ -64,10 +66,23 @@ TEST_CASE("Scores binary", "[Scores]")
 TEST_CASE("Scores multiclass", "[Scores]")
 {
     std::vector<int> y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 };
-    std::vector<int> y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 };
+    // Refactor y_pred to a tensor of shape [10, 3] with probabilities
+    std::vector<std::array<double, 3>> probs = {
+        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
+        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
+        { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
+        { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
+        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
+        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
+        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
+        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
+        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
+        { 0.0, 0.0, 1.0 }  // P(class 0)=0, P(class 1)=0, P(class 2)=1
+    };
+    torch::Tensor y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 3 }, torch::kFloat64).clone();
+    // Convert y_test to a tensor
     auto y_test_tensor = torch::tensor(y_test, torch::kInt32);
-    auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
-    platform::Scores scores(y_test_tensor, y_pred_tensor, 3);
+    platform::Scores scores(y_test_tensor, y_pred, 3);
     REQUIRE(scores.accuracy() == Catch::Approx(0.6).epsilon(epsilon));
     REQUIRE(scores.f1_score(0) == Catch::Approx(0.666667));
     REQUIRE(scores.f1_score(1) == Catch::Approx(0.4));
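The same ten-row one-hot probability block reappears in each of the multiclass test cases below. As an aside (not part of the diff; the helper name is illustrative), an equivalent tensor could be generated from the hard labels with libtorch's one_hot:

    #include <torch/torch.h>
    #include <vector>

    // Illustrative helper: expand hard class labels into a [N, n_classes]
    // probability tensor with 1.0 in the predicted class column.
    torch::Tensor labels_to_probs(const std::vector<int>& y_pred, int n_classes)
    {
        auto labels = torch::tensor(y_pred, torch::kInt64);            // one_hot requires int64
        return torch::one_hot(labels, n_classes).to(torch::kFloat64);  // shape [N, n_classes]
    }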
@@ -84,10 +99,21 @@ TEST_CASE("Scores multiclass", "[Scores]")
 TEST_CASE("Test Confusion Matrix Values", "[Scores]")
 {
     std::vector<int> y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 };
-    std::vector<int> y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 };
+    std::vector<std::array<double, 3>> probs = {
+        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
+        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
+        { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
+        { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
+        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
+        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
+        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
+        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
+        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
+        { 0.0, 0.0, 1.0 }  // P(class 0)=0, P(class 1)=0, P(class 2)=1
+    };
+    torch::Tensor y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 3 }, torch::kFloat64).clone();
     auto y_test_tensor = torch::tensor(y_test, torch::kInt32);
-    auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
-    platform::Scores scores(y_test_tensor, y_pred_tensor, 3);
+    platform::Scores scores(y_test_tensor, y_pred, 3);
     auto confusion_matrix = scores.get_confusion_matrix();
     REQUIRE(confusion_matrix[0][0].item<int>() == 2);
     REQUIRE(confusion_matrix[0][1].item<int>() == 1);
@@ -102,11 +128,22 @@ TEST_CASE("Test Confusion Matrix Values", "[Scores]")
 TEST_CASE("Confusion Matrix JSON", "[Scores]")
 {
     std::vector<int> y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 };
-    std::vector<int> y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 };
+    std::vector<std::array<double, 3>> probs = {
+        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
+        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
+        { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
+        { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
+        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
+        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
+        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
+        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
+        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
+        { 0.0, 0.0, 1.0 }  // P(class 0)=0, P(class 1)=0, P(class 2)=1
+    };
+    torch::Tensor y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 3 }, torch::kFloat64).clone();
     auto y_test_tensor = torch::tensor(y_test, torch::kInt32);
-    auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
     std::vector<std::string> labels = { "Aeroplane", "Boat", "Car" };
-    platform::Scores scores(y_test_tensor, y_pred_tensor, 3, labels);
+    platform::Scores scores(y_test_tensor, y_pred, 3, labels);
     auto res_json_int = scores.get_confusion_matrix_json();
     REQUIRE(res_json_int[0][0] == 2);
     REQUIRE(res_json_int[0][1] == 1);
@@ -131,11 +168,22 @@ TEST_CASE("Confusion Matrix JSON", "[Scores]")
 TEST_CASE("Classification Report", "[Scores]")
 {
     std::vector<int> y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 };
-    std::vector<int> y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 };
+    std::vector<std::array<double, 3>> probs = {
+        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
+        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
+        { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
+        { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
+        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
+        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
+        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
+        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
+        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
+        { 0.0, 0.0, 1.0 }  // P(class 0)=0, P(class 1)=0, P(class 2)=1
+    };
+    torch::Tensor y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 3 }, torch::kFloat64).clone();
     auto y_test_tensor = torch::tensor(y_test, torch::kInt32);
-    auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
     std::vector<std::string> labels = { "Aeroplane", "Boat", "Car" };
-    platform::Scores scores(y_test_tensor, y_pred_tensor, 3, labels);
+    platform::Scores scores(y_test_tensor, y_pred, 3, labels);
     auto report = scores.classification_report(Colors::BLUE(), "train");
     auto json_matrix = scores.get_confusion_matrix_json(true);
     platform::Scores scores2(json_matrix);
@@ -144,11 +192,22 @@ TEST_CASE("Classification Report", "[Scores]")
 TEST_CASE("JSON constructor", "[Scores]")
 {
     std::vector<int> y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 };
-    std::vector<int> y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 };
+    std::vector<std::array<double, 3>> probs = {
+        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
+        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
+        { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
+        { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
+        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
+        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
+        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
+        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
+        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
+        { 0.0, 0.0, 1.0 }  // P(class 0)=0, P(class 1)=0, P(class 2)=1
+    };
+    torch::Tensor y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 3 }, torch::kFloat64).clone();
     auto y_test_tensor = torch::tensor(y_test, torch::kInt32);
-    auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
     std::vector<std::string> labels = { "Car", "Boat", "Aeroplane" };
-    platform::Scores scores(y_test_tensor, y_pred_tensor, 3, labels);
+    platform::Scores scores(y_test_tensor, y_pred, 3, labels);
     auto res_json_int = scores.get_confusion_matrix_json();
     platform::Scores scores2(res_json_int);
     REQUIRE(scores.accuracy() == scores2.accuracy());
@@ -173,17 +232,14 @@ TEST_CASE("JSON constructor", "[Scores]")
 TEST_CASE("Aggregate", "[Scores]")
 {
     std::vector<int> y_test;
-    std::vector<int> y_pred;
+    torch::Tensor y_pred;
     make_test_bin(197, 210, 52, 41, y_test, y_pred);
     auto y_test_tensor = torch::tensor(y_test, torch::kInt32);
-    auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
-    platform::Scores scores(y_test_tensor, y_pred_tensor, 2);
+    platform::Scores scores(y_test_tensor, y_pred, 2);
     y_test.clear();
-    y_pred.clear();
     make_test_bin(227, 187, 39, 47, y_test, y_pred);
     auto y_test_tensor2 = torch::tensor(y_test, torch::kInt32);
-    auto y_pred_tensor2 = torch::tensor(y_pred, torch::kInt32);
-    platform::Scores scores2(y_test_tensor2, y_pred_tensor2, 2);
+    platform::Scores scores2(y_test_tensor2, y_pred, 2);
     scores.aggregate(scores2);
     REQUIRE(scores.accuracy() == Catch::Approx(0.821).epsilon(epsilon));
     REQUIRE(scores.f1_score(0) == Catch::Approx(0.8160329));
@@ -195,11 +251,9 @@ TEST_CASE("Aggregate", "[Scores]")
     REQUIRE(scores.f1_weighted() == Catch::Approx(0.8209856));
     REQUIRE(scores.f1_macro() == Catch::Approx(0.8208694));
     y_test.clear();
-    y_pred.clear();
     make_test_bin(197 + 227, 210 + 187, 52 + 39, 41 + 47, y_test, y_pred);
     y_test_tensor = torch::tensor(y_test, torch::kInt32);
-    y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
-    platform::Scores scores3(y_test_tensor, y_pred_tensor, 2);
+    platform::Scores scores3(y_test_tensor, y_pred, 2);
     for (int i = 0; i < 2; ++i) {
         REQUIRE(scores3.f1_score(i) == scores.f1_score(i));
         REQUIRE(scores3.precision(i) == scores.precision(i));
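A brief aside on the pooled call just above (not part of the diff): aggregating the two folds is equivalent to scoring the summed confusion counts, so make_test_bin(197 + 227, 210 + 187, 52 + 39, 41 + 47, ...) describes 424 + 397 = 821 correct predictions out of 1000 samples, which is exactly the 0.821 accuracy asserted after scores.aggregate(scores2).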
@@ -212,11 +266,22 @@ TEST_CASE("Aggregate", "[Scores]")
 TEST_CASE("Order of keys", "[Scores]")
 {
     std::vector<int> y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 };
-    std::vector<int> y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 };
+    std::vector<std::array<double, 3>> probs = {
+        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
+        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
+        { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
+        { 0.0, 0.0, 1.0 }, // P(class 0)=0, P(class 1)=0, P(class 2)=1
+        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
+        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
+        { 0.0, 1.0, 0.0 }, // P(class 0)=0, P(class 1)=1, P(class 2)=0
+        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
+        { 1.0, 0.0, 0.0 }, // P(class 0)=1, P(class 1)=0, P(class 2)=0
+        { 0.0, 0.0, 1.0 }  // P(class 0)=0, P(class 1)=0, P(class 2)=1
+    };
+    torch::Tensor y_pred = torch::from_blob(probs.data(), { (long)probs.size(), 3 }, torch::kFloat64).clone();
     auto y_test_tensor = torch::tensor(y_test, torch::kInt32);
-    auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
     std::vector<std::string> labels = { "Car", "Boat", "Aeroplane" };
-    platform::Scores scores(y_test_tensor, y_pred_tensor, 3, labels);
+    platform::Scores scores(y_test_tensor, y_pred, 3, labels);
     auto res_json_int = scores.get_confusion_matrix_json(true);
     // Make a temp file and store the json
     std::string filename = "temp.json";
@@ -5,7 +5,7 @@
 #include <vector>
 #include <map>
 #include <tuple>
-#include <ArffFiles.hpp>
+#include <ArffFiles/ArffFiles.hpp>
 #include <fimdlp/CPPFImdlp.h>
 
 bool file_exists(const std::string& name);
@@ -7,6 +7,7 @@
     "fimdlp",
     "libtorch-bin",
     "folding",
+    "catch2",
     "argparse"
   ],
   "overrides": [
@@ -30,9 +31,13 @@
       "name": "argpase",
       "version": "3.2"
     },
+    {
+      "name": "catch2",
+      "version": "3.8.1"
+    },
     {
       "name": "nlohmann-json",
       "version": "3.11.3"
     }
   ]
 }