Fix epsilont early stopping in BoostAODE
This commit is contained in:
@@ -121,6 +121,7 @@ namespace bayesnet {
|
||||
}
|
||||
void BoostAODE::trainModel(const torch::Tensor& weights)
|
||||
{
|
||||
fitted = true;
|
||||
// Algorithm based on the adaboost algorithm for classification
|
||||
// as explained in Ensemble methods (Zhi-Hua Zhou, 2012)
|
||||
std::unordered_set<int> featuresUsed;
|
||||
@@ -161,7 +162,6 @@ namespace bayesnet {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
featuresUsed.insert(feature);
|
||||
model = std::make_unique<SPODE>(feature);
|
||||
model->fit(dataset, features, className, states, weights_);
|
||||
auto ypred = model->predict(X_train);
|
||||
@@ -170,6 +170,12 @@ namespace bayesnet {
|
||||
auto mask_right = ypred == y_train;
|
||||
auto masked_weights = weights_ * mask_wrong.to(weights_.dtype());
|
||||
double epsilon_t = masked_weights.sum().item<double>();
|
||||
if (epsilon_t > 0.5) {
|
||||
// Inverse the weights policy (plot ln(wt))
|
||||
// "In each round of AdaBoost, there is a sanity check to ensure that the current base
|
||||
// learner is better than random guess" (Zhi-Hua Zhou, 2012)
|
||||
break;
|
||||
}
|
||||
double wt = (1 - epsilon_t) / epsilon_t;
|
||||
double alpha_t = epsilon_t == 0 ? 1 : 0.5 * log(wt);
|
||||
// Step 3.2: Update weights for next classifier
|
||||
@@ -181,6 +187,7 @@ namespace bayesnet {
|
||||
double totalWeights = torch::sum(weights_).item<double>();
|
||||
weights_ = weights_ / totalWeights;
|
||||
// Step 3.4: Store classifier and its accuracy to weigh its future vote
|
||||
featuresUsed.insert(feature);
|
||||
models.push_back(std::move(model));
|
||||
significanceModels.push_back(alpha_t);
|
||||
n_models++;
|
||||
@@ -197,15 +204,13 @@ namespace bayesnet {
|
||||
}
|
||||
priorAccuracy = accuracy;
|
||||
}
|
||||
// epsilon_t > 0.5 => inverse the weights policy (plot ln(wt))
|
||||
exitCondition = n_models >= maxModels && repeatSparent || epsilon_t > 0.5 || count > tolerance;
|
||||
exitCondition = n_models >= maxModels && repeatSparent || count > tolerance;
|
||||
}
|
||||
if (featuresUsed.size() != features.size()) {
|
||||
notes.push_back("Used features in train: " + std::to_string(featuresUsed.size()) + " of " + std::to_string(features.size()));
|
||||
status = WARNING;
|
||||
}
|
||||
notes.push_back("Number of models: " + std::to_string(n_models));
|
||||
fitted = true;
|
||||
}
|
||||
std::vector<std::string> BoostAODE::graph(const std::string& title) const
|
||||
{
|
||||
|
Reference in New Issue
Block a user