Fix predict_proba in AdaBoost
This commit is contained in:
@@ -74,32 +74,53 @@ namespace bayesnet {
|
|||||||
break; // Stop boosting
|
break; // Stop boosting
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check for perfect classification BEFORE calculating alpha
|
||||||
|
if (weighted_error <= 1e-10) {
|
||||||
|
if (debug) std::cout << " Perfect classification achieved (error=" << weighted_error << ")" << std::endl;
|
||||||
|
|
||||||
|
// For perfect classification, use a large but finite alpha
|
||||||
|
double alpha = 10.0 + std::log(static_cast<double>(n_classes - 1));
|
||||||
|
|
||||||
|
// Store the estimator and its weight
|
||||||
|
models.push_back(std::move(estimator));
|
||||||
|
alphas.push_back(alpha);
|
||||||
|
|
||||||
|
if (debug) {
|
||||||
|
std::cout << "Iteration " << iter << ":" << std::endl;
|
||||||
|
std::cout << " Weighted error: " << weighted_error << std::endl;
|
||||||
|
std::cout << " Alpha (finite): " << alpha << std::endl;
|
||||||
|
std::cout << " Random guess error: " << random_guess_error << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
break; // Stop training as we have a perfect classifier
|
||||||
|
}
|
||||||
|
|
||||||
// Calculate alpha (estimator weight) using SAMME formula
|
// Calculate alpha (estimator weight) using SAMME formula
|
||||||
// alpha = log((1 - err) / err) + log(K - 1)
|
// alpha = log((1 - err) / err) + log(K - 1)
|
||||||
double alpha = std::log((1.0 - weighted_error) / weighted_error) +
|
// Clamp weighted_error to avoid division by zero and infinite alpha
|
||||||
|
double clamped_error = std::max(1e-15, std::min(1.0 - 1e-15, weighted_error));
|
||||||
|
double alpha = std::log((1.0 - clamped_error) / clamped_error) +
|
||||||
std::log(static_cast<double>(n_classes - 1));
|
std::log(static_cast<double>(n_classes - 1));
|
||||||
|
|
||||||
|
// Clamp alpha to reasonable bounds to avoid numerical issues
|
||||||
|
alpha = std::max(-10.0, std::min(10.0, alpha));
|
||||||
|
|
||||||
// Store the estimator and its weight
|
// Store the estimator and its weight
|
||||||
models.push_back(std::move(estimator));
|
models.push_back(std::move(estimator));
|
||||||
alphas.push_back(alpha);
|
alphas.push_back(alpha);
|
||||||
|
|
||||||
// Update sample weights
|
// Update sample weights (only if this is not the last iteration)
|
||||||
updateSampleWeights(models.back().get(), alpha);
|
if (iter < n_estimators - 1) {
|
||||||
|
updateSampleWeights(models.back().get(), alpha);
|
||||||
// Normalize weights
|
normalizeWeights();
|
||||||
normalizeWeights();
|
}
|
||||||
|
|
||||||
if (debug) {
|
if (debug) {
|
||||||
std::cout << "Iteration " << iter << ":" << std::endl;
|
std::cout << "Iteration " << iter << ":" << std::endl;
|
||||||
std::cout << " Weighted error: " << weighted_error << std::endl;
|
std::cout << " Weighted error: " << weighted_error << std::endl;
|
||||||
std::cout << " Alpha: " << alpha << std::endl;
|
std::cout << " Alpha: " << alpha << std::endl;
|
||||||
std::cout << " Random guess error: " << random_guess_error << std::endl;
|
std::cout << " Random guess error: " << random_guess_error << std::endl;
|
||||||
}
|
std::cout << " Random guess error: " << random_guess_error << std::endl;
|
||||||
|
|
||||||
// Check for perfect classification
|
|
||||||
if (weighted_error < 1e-10) {
|
|
||||||
if (debug) std::cout << " Perfect classification achieved, stopping" << std::endl;
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -184,10 +184,9 @@ TEST_CASE("AdaBoost Basic Functionality", "[AdaBoost]")
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check that predict_proba matches the expected predict value
|
// Check that predict_proba matches the expected predict value
|
||||||
// REQUIRE(pred == (p[0] > p[1] ? 0 : 1));
|
REQUIRE(pred == (p[0] > p[1] ? 0 : 1));
|
||||||
}
|
}
|
||||||
double accuracy = static_cast<double>(correct) / n_samples;
|
double accuracy = static_cast<double>(correct) / n_samples;
|
||||||
std::cout << "Probability accuracy: " << accuracy << std::endl;
|
|
||||||
REQUIRE(accuracy > 0.99); // Should achieve good accuracy on this simple dataset
|
REQUIRE(accuracy > 0.99); // Should achieve good accuracy on this simple dataset
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -711,6 +710,8 @@ TEST_CASE("AdaBoost SAMME Algorithm Validation", "[AdaBoost]")
|
|||||||
for (size_t i = 0; i < predictions.size(); i++) {
|
for (size_t i = 0; i < predictions.size(); i++) {
|
||||||
int pred = predictions[i];
|
int pred = predictions[i];
|
||||||
auto probs = probabilities[i];
|
auto probs = probabilities[i];
|
||||||
|
INFO("Sample " << i << ": predicted=" << pred
|
||||||
|
<< ", probabilities=[" << probs[0] << ", " << probs[1] << "]");
|
||||||
|
|
||||||
REQUIRE(pred == (probs[0] > probs[1] ? 0 : 1));
|
REQUIRE(pred == (probs[0] > probs[1] ? 0 : 1));
|
||||||
REQUIRE(probs[0] + probs[1] == Catch::Approx(1.0).epsilon(1e-6));
|
REQUIRE(probs[0] + probs[1] == Catch::Approx(1.0).epsilon(1e-6));
|
||||||
|
Reference in New Issue
Block a user