diff --git a/src/experimental_clfs/AdaBoost.cpp b/src/experimental_clfs/AdaBoost.cpp
index 5af7f31..407c587 100644
--- a/src/experimental_clfs/AdaBoost.cpp
+++ b/src/experimental_clfs/AdaBoost.cpp
@@ -74,32 +74,53 @@ namespace bayesnet {
                 break; // Stop boosting
             }
 
+            // Check for perfect classification BEFORE calculating alpha
+            if (weighted_error <= 1e-10) {
+                if (debug) std::cout << "  Perfect classification achieved (error=" << weighted_error << ")" << std::endl;
+
+                // For perfect classification, use a large but finite alpha
+                double alpha = 10.0 + std::log(static_cast<double>(n_classes - 1));
+
+                // Store the estimator and its weight
+                models.push_back(std::move(estimator));
+                alphas.push_back(alpha);
+
+                if (debug) {
+                    std::cout << "Iteration " << iter << ":" << std::endl;
+                    std::cout << "  Weighted error: " << weighted_error << std::endl;
+                    std::cout << "  Alpha (finite): " << alpha << std::endl;
+                    std::cout << "  Random guess error: " << random_guess_error << std::endl;
+                }
+
+                break; // Stop training as we have a perfect classifier
+            }
+
             // Calculate alpha (estimator weight) using SAMME formula
             // alpha = log((1 - err) / err) + log(K - 1)
-            double alpha = std::log((1.0 - weighted_error) / weighted_error) +
+            // Clamp weighted_error to avoid division by zero and infinite alpha
+            double clamped_error = std::max(1e-15, std::min(1.0 - 1e-15, weighted_error));
+            double alpha = std::log((1.0 - clamped_error) / clamped_error) +
                 std::log(static_cast<double>(n_classes - 1));
 
+            // Clamp alpha to reasonable bounds to avoid numerical issues
+            alpha = std::max(-10.0, std::min(10.0, alpha));
+
             // Store the estimator and its weight
             models.push_back(std::move(estimator));
             alphas.push_back(alpha);
 
-            // Update sample weights
-            updateSampleWeights(models.back().get(), alpha);
-
-            // Normalize weights
-            normalizeWeights();
+            // Update sample weights (only if this is not the last iteration)
+            if (iter < n_estimators - 1) {
+                updateSampleWeights(models.back().get(), alpha);
+                normalizeWeights();
+            }
 
             if (debug) {
                 std::cout << "Iteration " << iter << ":" << std::endl;
                 std::cout << "  Weighted error: " << weighted_error << std::endl;
                 std::cout << "  Alpha: " << alpha << std::endl;
                 std::cout << "  Random guess error: " << random_guess_error << std::endl;
-            }
-
-            // Check for perfect classification
-            if (weighted_error < 1e-10) {
-                if (debug) std::cout << "  Perfect classification achieved, stopping" << std::endl;
-                break;
             }
         }
diff --git a/tests/TestAdaBoost.cpp b/tests/TestAdaBoost.cpp
index 301ebb2..7fb887b 100644
--- a/tests/TestAdaBoost.cpp
+++ b/tests/TestAdaBoost.cpp
@@ -184,10 +184,9 @@ TEST_CASE("AdaBoost Basic Functionality", "[AdaBoost]")
            }
            // Check that predict_proba matches the expected predict value
-           // REQUIRE(pred == (p[0] > p[1] ? 0 : 1));
+           REQUIRE(pred == (p[0] > p[1] ? 0 : 1));
        }
 
        double accuracy = static_cast<double>(correct) / n_samples;
-       std::cout << "Probability accuracy: " << accuracy << std::endl;
        REQUIRE(accuracy > 0.99); // Should achieve good accuracy on this simple dataset
    }
}
@@ -711,6 +710,8 @@ TEST_CASE("AdaBoost SAMME Algorithm Validation", "[AdaBoost]")
    for (size_t i = 0; i < predictions.size(); i++) {
        int pred = predictions[i];
        auto probs = probabilities[i];
+       INFO("Sample " << i << ": predicted=" << pred
+           << ", probabilities=[" << probs[0] << ", " << probs[1] << "]");
        REQUIRE(pred == (probs[0] > probs[1] ? 0 : 1));
        REQUIRE(probs[0] + probs[1] == Catch::Approx(1.0).epsilon(1e-6));
    }
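
For reference, here is the clamped SAMME estimator-weight computation from the patch as a minimal standalone sketch. The `samme_alpha` helper name and the sample inputs are illustrative only, not part of the patched `AdaBoost` class:

```cpp
#include <algorithm>
#include <cmath>
#include <iostream>

// Sketch of the clamped SAMME weight: alpha = log((1 - err) / err) + log(K - 1),
// with err clamped away from 0 and 1 so log() stays finite,
// and alpha bounded to [-10, 10] as in the patch.
double samme_alpha(double weighted_error, int n_classes) {
    double clamped = std::max(1e-15, std::min(1.0 - 1e-15, weighted_error));
    double alpha = std::log((1.0 - clamped) / clamped)
        + std::log(static_cast<double>(n_classes - 1));
    return std::max(-10.0, std::min(10.0, alpha)); // bound to avoid numerical blow-up
}

int main() {
    std::cout << samme_alpha(0.30, 3) << "\n"; // ~1.54: better than the chance error of 2/3
    std::cout << samme_alpha(0.0, 3) << "\n";  // clamped to 10 instead of +inf
    std::cout << samme_alpha(0.95, 3) << "\n"; // negative: worse than chance
}
```

Capping alpha also bounds how far a single estimator can skew the sample weights, which is the same rationale behind the perfect-classification branch above substituting the finite `10.0 + log(K - 1)` for an otherwise infinite weight.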