Add comments to BoostAODE algorithm

This commit is contained in:
2024-02-19 22:58:15 +01:00
parent f3b8150e2c
commit c7555dac3f
5 changed files with 11 additions and 5 deletions

BIN
docs/BoostAODE.docx Normal file

Binary file not shown.

1
lib/argparse Submodule

Submodule lib/argparse added at 69dabd88a8

1
lib/libxlsxwriter Submodule

Submodule lib/libxlsxwriter added at 29355a0887

View File

@@ -121,6 +121,8 @@ namespace bayesnet {
} }
void BoostAODE::trainModel(const torch::Tensor& weights) void BoostAODE::trainModel(const torch::Tensor& weights)
{ {
// Algorithm based on the adaboost algorithm for classification
// as explained in Ensemble methods (Zhi-Hua Zhou, 2012)
std::unordered_set<int> featuresUsed; std::unordered_set<int> featuresUsed;
if (selectFeatures) { if (selectFeatures) {
featuresUsed = initializeModels(); featuresUsed = initializeModels();
@@ -132,9 +134,8 @@ namespace bayesnet {
// Variables to control the accuracy finish condition // Variables to control the accuracy finish condition
double priorAccuracy = 0.0; double priorAccuracy = 0.0;
double delta = 1.0; double delta = 1.0;
double threshold = 1e-4; double convergence_threshold = 1e-4;
int count = 0; // number of times the accuracy is lower than the threshold int count = 0; // number of times the accuracy is lower than the convergence_threshold
fitted = true; // to enable predict
// Step 0: Set the finish condition // Step 0: Set the finish condition
// if not repeatSparent a finish condition is run out of features // if not repeatSparent a finish condition is run out of features
// n_models == maxModels // n_models == maxModels
@@ -191,10 +192,12 @@ namespace bayesnet {
} else { } else {
delta = accuracy - priorAccuracy; delta = accuracy - priorAccuracy;
} }
if (delta < threshold) { if (delta < convergence_threshold) {
count++; count++;
} }
priorAccuracy = accuracy;
} }
// epsilon_t > 0.5 => inverse the weights policy (plot ln(wt))
exitCondition = n_models >= maxModels && repeatSparent || epsilon_t > 0.5 || count > tolerance; exitCondition = n_models >= maxModels && repeatSparent || epsilon_t > 0.5 || count > tolerance;
} }
if (featuresUsed.size() != features.size()) { if (featuresUsed.size() != features.size()) {
@@ -202,6 +205,7 @@ namespace bayesnet {
status = WARNING; status = WARNING;
} }
notes.push_back("Number of models: " + std::to_string(n_models)); notes.push_back("Number of models: " + std::to_string(n_models));
fitted = true;
} }
std::vector<std::string> BoostAODE::graph(const std::string& title) const std::vector<std::string> BoostAODE::graph(const std::string& title) const
{ {