From 41afa1b8883b68c8d531d4c06bc2237f508dc411 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?=
Date: Wed, 18 Jun 2025 17:33:56 +0200
Subject: [PATCH] Enhance predictProbaSample

---
 src/experimental_clfs/AdaBoost.cpp | 73 +++++++++++++++++++++---------
 1 file changed, 51 insertions(+), 22 deletions(-)

diff --git a/src/experimental_clfs/AdaBoost.cpp b/src/experimental_clfs/AdaBoost.cpp
index 407c587..3e276a2 100644
--- a/src/experimental_clfs/AdaBoost.cpp
+++ b/src/experimental_clfs/AdaBoost.cpp
@@ -412,52 +412,81 @@ namespace bayesnet {
                 std::to_string(n) + " but got " + std::to_string(x.size(0)));
         }
 
-        // Initialize class votes (same logic as predictSample)
+        // Initialize class votes with zeros
         std::vector<double> class_votes(n_classes, 0.0);
+        double total_votes = 0.0;
 
-        // Accumulate weighted votes from all estimators (SAMME voting)
-        double total_alpha = 0.0;
+        if (debug) {
+            std::cout << "=== predictProbaSample Debug ===" << std::endl;
+            std::cout << "Number of models: " << models.size() << std::endl;
+            std::cout << "Number of classes: " << n_classes << std::endl;
+        }
+
+        // Accumulate votes from all estimators
         for (size_t i = 0; i < models.size(); i++) {
-            if (alphas[i] <= 0) continue;  // Skip estimators with zero or negative weight
+            double alpha = alphas[i];
+
+            // Skip invalid estimators
+            if (alpha <= 0 || !std::isfinite(alpha)) {
+                if (debug) std::cout << "Skipping model " << i << " (alpha=" << alpha << ")" << std::endl;
+                continue;
+            }
 
             try {
-                // Get class prediction from this estimator (not probabilities!)
+                // Get class prediction from this estimator
                 int predicted_class = static_cast<DecisionTree*>(models[i].get())->predictSample(x);
 
-                // Add weighted vote for this class (SAMME algorithm)
+                if (debug) {
+                    std::cout << "Model " << i << ": predicts class " << predicted_class
+                        << " with alpha " << alpha << std::endl;
+                }
+
+                // Add weighted vote for this class
                 if (predicted_class >= 0 && predicted_class < n_classes) {
-                    class_votes[predicted_class] += alphas[i];
-                    total_alpha += alphas[i];
+                    class_votes[predicted_class] += alpha;
+                    total_votes += alpha;
+                } else {
+                    if (debug) std::cout << "Invalid class prediction: " << predicted_class << std::endl;
                 }
             }
             catch (const std::exception& e) {
-                std::cerr << "Error in estimator " << i << ": " << e.what() << std::endl;
+                if (debug) std::cout << "Error in model " << i << ": " << e.what() << std::endl;
                 continue;
            }
         }
 
+        if (debug) {
+            std::cout << "Total votes: " << total_votes << std::endl;
+            std::cout << "Class votes: [";
+            for (int j = 0; j < n_classes; j++) {
+                std::cout << class_votes[j];
+                if (j < n_classes - 1) std::cout << ", ";
+            }
+            std::cout << "]" << std::endl;
+        }
+
         // Convert votes to probabilities
         torch::Tensor class_probs = torch::zeros({ n_classes }, torch::kFloat);
 
-        if (total_alpha > 0) {
-            // Normalize votes to get probabilities
+        if (total_votes > 0) {
+            // Simple division to get probabilities
             for (int j = 0; j < n_classes; j++) {
-                class_probs[j] = static_cast<float>(class_votes[j] / total_alpha);
+                class_probs[j] = static_cast<float>(class_votes[j] / total_votes);
             }
         } else {
-            // If no valid estimators, return uniform distribution
+            // If no valid votes, uniform distribution
+            if (debug) std::cout << "No valid votes, using uniform distribution" << std::endl;
             class_probs.fill_(1.0f / n_classes);
         }
 
-        // Ensure probabilities are valid (they should be already, but just in case)
-        class_probs = torch::clamp(class_probs, 0.0f, 1.0f);
-
-        // Verify they sum to 1 (they should, but normalize if needed due to floating point errors)
-        float sum_probs = torch::sum(class_probs).item<float>();
-        if (sum_probs > 1e-15f) {
-            class_probs = class_probs / sum_probs;
-        } else {
-            class_probs.fill_(1.0f / n_classes);
+        if (debug) {
+            std::cout << "Final probabilities: [";
+            for (int j = 0; j < n_classes; j++) {
+                std::cout << class_probs[j].item<float>();
+                if (j < n_classes - 1) std::cout << ", ";
+            }
+            std::cout << "]" << std::endl;
+            std::cout << "=== End predictProbaSample Debug ===" << std::endl;
         }
 
         return class_probs;
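
Note: the voting scheme this patch settles on can be summarised outside the torch and debug plumbing. The sketch below is illustrative only and is not part of the patch; the name votesToProbabilities and the plain std::vector inputs are hypothetical stand-ins for the alphas vector and the per-model predictSample calls in AdaBoost.cpp. Each estimator with a positive, finite alpha adds that alpha to its predicted class, and the vote vector is divided by the total alpha mass, falling back to a uniform distribution when nothing voted.

// Illustrative sketch, not part of the patch: SAMME-style weighted votes
// converted to a probability distribution the way predictProbaSample does.
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<double> votesToProbabilities(const std::vector<double>& alphas,
                                         const std::vector<int>& predictions,
                                         int n_classes)
{
    std::vector<double> votes(n_classes, 0.0);
    double total_votes = 0.0;
    for (std::size_t i = 0; i < alphas.size(); i++) {
        double alpha = alphas[i];
        int predicted_class = predictions[i];
        // Skip estimators with non-positive or non-finite weight, as the patch does
        if (alpha <= 0 || !std::isfinite(alpha)) continue;
        // Ignore out-of-range class predictions
        if (predicted_class < 0 || predicted_class >= n_classes) continue;
        votes[predicted_class] += alpha;
        total_votes += alpha;
    }
    if (total_votes > 0) {
        // Normalize by the total alpha mass
        for (int j = 0; j < n_classes; j++) votes[j] /= total_votes;
    } else {
        // No valid votes: uniform distribution
        for (int j = 0; j < n_classes; j++) votes[j] = 1.0 / n_classes;
    }
    return votes;
}

For example, with alphas {0.8, 0.5, 0.3} and predictions {1, 1, 0} over three classes, the total vote mass is 1.6 and the result is {0.1875, 0.8125, 0.0}.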