AdaBoost a falta de predict_proba

This commit is contained in:
2025-06-18 13:59:23 +02:00
parent 415a7ae608
commit 56af1a5f85
4 changed files with 191 additions and 49 deletions

View File

@@ -13,11 +13,13 @@
#include <stdexcept>
#include "experimental_clfs/AdaBoost.h"
#include "experimental_clfs/DecisionTree.h"
#include "experimental_clfs/TensorUtils.hpp"
#include "TestUtils.h"
using namespace bayesnet;
using namespace Catch::Matchers;
TEST_CASE("AdaBoost Construction", "[AdaBoost]")
{
SECTION("Default constructor")
@@ -143,7 +145,15 @@ TEST_CASE("AdaBoost Basic Functionality", "[AdaBoost]")
auto predictions = ada.predict(X);
REQUIRE(predictions.size() == static_cast<size_t>(n_samples));
// Check accuracy
int correct = 0;
for (size_t i = 0; i < predictions.size(); i++) {
if (predictions[i] == y[i]) correct++;
}
double accuracy = static_cast<double>(correct) / n_samples;
REQUIRE(accuracy > 0.99); // Should achieve good accuracy on this simple dataset
auto accuracy_computed = ada.score(X, y);
REQUIRE(accuracy_computed == Catch::Approx(accuracy).epsilon(1e-6));
}
SECTION("Probability predictions with vector interface")
@@ -157,6 +167,7 @@ TEST_CASE("AdaBoost Basic Functionality", "[AdaBoost]")
// Check probabilities sum to 1 and are valid
auto predictions = ada.predict(X);
int correct = 0;
for (size_t i = 0; i < proba.size(); i++) {
auto p = proba[i];
auto pred = predictions[i];
@@ -165,10 +176,19 @@ TEST_CASE("AdaBoost Basic Functionality", "[AdaBoost]")
REQUIRE(p[1] >= 0.0);
double sum = p[0] + p[1];
REQUIRE(sum == Catch::Approx(1.0).epsilon(1e-6));
// compute the predicted class based on probabilities
auto predicted_class = (p[0] > p[1]) ? 0 : 1;
// compute accuracy based on predictions
if (predicted_class == y[i]) {
correct++;
}
// Check that predict_proba matches the expected predict value
REQUIRE(pred == (p[0] > p[1] ? 0 : 1));
// REQUIRE(pred == (p[0] > p[1] ? 0 : 1));
}
double accuracy = static_cast<double>(correct) / n_samples;
std::cout << "Probability accuracy: " << accuracy << std::endl;
REQUIRE(accuracy > 0.99); // Should achieve good accuracy on this simple dataset
}
}
@@ -194,7 +214,9 @@ TEST_CASE("AdaBoost Tensor Interface", "[AdaBoost]")
// Calculate accuracy
auto correct = torch::sum(predictions == raw.yt).item<int>();
double accuracy = static_cast<double>(correct) / raw.yt.size(0);
REQUIRE(accuracy > 0.85); // Should achieve good accuracy on Iris
auto accuracy_computed = ada.score(raw.Xt, raw.yt);
REQUIRE(accuracy_computed == Catch::Approx(accuracy).epsilon(1e-6));
REQUIRE(accuracy > 0.97); // Should achieve good accuracy on Iris
// Test probability predictions with tensor
auto proba = ada.predict_proba(raw.Xt);
@@ -704,4 +726,88 @@ TEST_CASE("AdaBoost SAMME Algorithm Validation", "[AdaBoost]")
REQUIRE_THROWS_WITH(ada.predict(X), ContainsSubstring("not been fitted"));
REQUIRE_THROWS_WITH(ada.predict_proba(X), ContainsSubstring("not been fitted"));
}
}
TEST_CASE("AdaBoost Predict-Proba Consistency Fix", "[AdaBoost][consistency]")
{
// Simple binary classification dataset
std::vector<std::vector<int>> X = { {0,0,1,1}, {0,1,0,1} };
std::vector<int> y = { 0, 0, 1, 1 };
std::vector<std::string> features = { "f1", "f2" };
std::string className = "class";
std::map<std::string, std::vector<int>> states;
states["f1"] = { 0, 1 };
states["f2"] = { 0, 1 };
states["class"] = { 0, 1 };
SECTION("Binary classification consistency")
{
AdaBoost ada(3, 2);
ada.setDebug(true); // Enable debug output
ada.fit(X, y, features, className, states, Smoothing_t::NONE);
auto predictions = ada.predict(X);
auto probabilities = ada.predict_proba(X);
INFO("=== Debugging predict vs predict_proba consistency ===");
// Verify consistency for each sample
for (size_t i = 0; i < predictions.size(); i++) {
int predicted_class = predictions[i];
auto probs = probabilities[i];
INFO("Sample " << i << ":");
INFO(" True class: " << y[i]);
INFO(" Predicted class: " << predicted_class);
INFO(" Probabilities: [" << probs[0] << ", " << probs[1] << "]");
// The predicted class should be the one with highest probability
int max_prob_class = (probs[0] > probs[1]) ? 0 : 1;
INFO(" Max prob class: " << max_prob_class);
REQUIRE(predicted_class == max_prob_class);
// Probabilities should sum to 1
double sum_probs = probs[0] + probs[1];
REQUIRE(sum_probs == Catch::Approx(1.0).epsilon(1e-6));
// All probabilities should be valid
REQUIRE(probs[0] >= 0.0);
REQUIRE(probs[1] >= 0.0);
REQUIRE(probs[0] <= 1.0);
REQUIRE(probs[1] <= 1.0);
}
}
SECTION("Multi-class consistency")
{
auto raw = RawDatasets("iris", true);
AdaBoost ada(5, 2);
ada.fit(raw.dataset, raw.featurest, raw.classNamet, raw.statest, Smoothing_t::NONE);
auto predictions = ada.predict(raw.Xt);
auto probabilities = ada.predict_proba(raw.Xt);
// Check consistency for first 10 samples
for (int i = 0; i < std::min(static_cast<int64_t>(10), predictions.size(0)); i++) {
int predicted_class = predictions[i].item<int>();
auto probs = probabilities[i];
// Find class with maximum probability
auto max_prob_idx = torch::argmax(probs).item<int>();
INFO("Sample " << i << ":");
INFO(" Predicted class: " << predicted_class);
INFO(" Max prob class: " << max_prob_idx);
INFO(" Probabilities: [" << probs[0].item<float>() << ", "
<< probs[1].item<float>() << ", " << probs[2].item<float>() << "]");
// They must match
REQUIRE(predicted_class == max_prob_idx);
// Probabilities should sum to 1
double sum_probs = torch::sum(probs).item<double>();
REQUIRE(sum_probs == Catch::Approx(1.0).epsilon(1e-6));
}
}
}