Enhance predictProbaSample
This commit is contained in:
@@ -412,52 +412,81 @@ namespace bayesnet {
|
||||
std::to_string(n) + " but got " + std::to_string(x.size(0)));
|
||||
}
|
||||
|
||||
// Initialize class votes (same logic as predictSample)
|
||||
// Initialize class votes with zeros
|
||||
std::vector<double> class_votes(n_classes, 0.0);
|
||||
double total_votes = 0.0;
|
||||
|
||||
// Accumulate weighted votes from all estimators (SAMME voting)
|
||||
double total_alpha = 0.0;
|
||||
if (debug) {
|
||||
std::cout << "=== predictProbaSample Debug ===" << std::endl;
|
||||
std::cout << "Number of models: " << models.size() << std::endl;
|
||||
std::cout << "Number of classes: " << n_classes << std::endl;
|
||||
}
|
||||
|
||||
// Accumulate votes from all estimators
|
||||
for (size_t i = 0; i < models.size(); i++) {
|
||||
if (alphas[i] <= 0) continue; // Skip estimators with zero or negative weight
|
||||
double alpha = alphas[i];
|
||||
|
||||
// Skip invalid estimators
|
||||
if (alpha <= 0 || !std::isfinite(alpha)) {
|
||||
if (debug) std::cout << "Skipping model " << i << " (alpha=" << alpha << ")" << std::endl;
|
||||
continue;
|
||||
}
|
||||
|
||||
try {
|
||||
// Get class prediction from this estimator (not probabilities!)
|
||||
// Get class prediction from this estimator
|
||||
int predicted_class = static_cast<DecisionTree*>(models[i].get())->predictSample(x);
|
||||
|
||||
// Add weighted vote for this class (SAMME algorithm)
|
||||
if (debug) {
|
||||
std::cout << "Model " << i << ": predicts class " << predicted_class
|
||||
<< " with alpha " << alpha << std::endl;
|
||||
}
|
||||
|
||||
// Add weighted vote for this class
|
||||
if (predicted_class >= 0 && predicted_class < n_classes) {
|
||||
class_votes[predicted_class] += alphas[i];
|
||||
total_alpha += alphas[i];
|
||||
class_votes[predicted_class] += alpha;
|
||||
total_votes += alpha;
|
||||
} else {
|
||||
if (debug) std::cout << "Invalid class prediction: " << predicted_class << std::endl;
|
||||
}
|
||||
}
|
||||
catch (const std::exception& e) {
|
||||
std::cerr << "Error in estimator " << i << ": " << e.what() << std::endl;
|
||||
if (debug) std::cout << "Error in model " << i << ": " << e.what() << std::endl;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (debug) {
|
||||
std::cout << "Total votes: " << total_votes << std::endl;
|
||||
std::cout << "Class votes: [";
|
||||
for (int j = 0; j < n_classes; j++) {
|
||||
std::cout << class_votes[j];
|
||||
if (j < n_classes - 1) std::cout << ", ";
|
||||
}
|
||||
std::cout << "]" << std::endl;
|
||||
}
|
||||
|
||||
// Convert votes to probabilities
|
||||
torch::Tensor class_probs = torch::zeros({ n_classes }, torch::kFloat);
|
||||
|
||||
if (total_alpha > 0) {
|
||||
// Normalize votes to get probabilities
|
||||
if (total_votes > 0) {
|
||||
// Simple division to get probabilities
|
||||
for (int j = 0; j < n_classes; j++) {
|
||||
class_probs[j] = static_cast<float>(class_votes[j] / total_alpha);
|
||||
class_probs[j] = static_cast<float>(class_votes[j] / total_votes);
|
||||
}
|
||||
} else {
|
||||
// If no valid estimators, return uniform distribution
|
||||
// If no valid votes, uniform distribution
|
||||
if (debug) std::cout << "No valid votes, using uniform distribution" << std::endl;
|
||||
class_probs.fill_(1.0f / n_classes);
|
||||
}
|
||||
|
||||
// Ensure probabilities are valid (they should be already, but just in case)
|
||||
class_probs = torch::clamp(class_probs, 0.0f, 1.0f);
|
||||
|
||||
// Verify they sum to 1 (they should, but normalize if needed due to floating point errors)
|
||||
float sum_probs = torch::sum(class_probs).item<float>();
|
||||
if (sum_probs > 1e-15f) {
|
||||
class_probs = class_probs / sum_probs;
|
||||
} else {
|
||||
class_probs.fill_(1.0f / n_classes);
|
||||
if (debug) {
|
||||
std::cout << "Final probabilities: [";
|
||||
for (int j = 0; j < n_classes; j++) {
|
||||
std::cout << class_probs[j].item<float>();
|
||||
if (j < n_classes - 1) std::cout << ", ";
|
||||
}
|
||||
std::cout << "]" << std::endl;
|
||||
std::cout << "=== End predictProbaSample Debug ===" << std::endl;
|
||||
}
|
||||
|
||||
return class_probs;
|
||||
|
Reference in New Issue
Block a user