diff --git a/docs/BoostAODE.docx b/docs/BoostAODE.docx new file mode 100644 index 0000000..ec04a70 Binary files /dev/null and b/docs/BoostAODE.docx differ diff --git a/lib/argparse b/lib/argparse new file mode 160000 index 0000000..69dabd8 --- /dev/null +++ b/lib/argparse @@ -0,0 +1 @@ +Subproject commit 69dabd88a8e6680b1a1a18397eb3e165e4019ce6 diff --git a/lib/catch2 b/lib/catch2 index 863c662..766541d 160000 --- a/lib/catch2 +++ b/lib/catch2 @@ -1 +1 @@ -Subproject commit 863c662c0eff026300f4d729a7054e90d6d12cdd +Subproject commit 766541d12d64845f5232a1ce4e34a85e83506b09 diff --git a/lib/libxlsxwriter b/lib/libxlsxwriter new file mode 160000 index 0000000..29355a0 --- /dev/null +++ b/lib/libxlsxwriter @@ -0,0 +1 @@ +Subproject commit 29355a0887475488c7cc470ad43cc867fcfa92e2 diff --git a/src/BayesNet/BoostAODE.cc b/src/BayesNet/BoostAODE.cc index efb95e6..3b1d0f1 100644 --- a/src/BayesNet/BoostAODE.cc +++ b/src/BayesNet/BoostAODE.cc @@ -121,6 +121,8 @@ namespace bayesnet { } void BoostAODE::trainModel(const torch::Tensor& weights) { + // Algorithm based on the AdaBoost algorithm for classification + // as explained in Ensemble Methods (Zhi-Hua Zhou, 2012) std::unordered_set featuresUsed; if (selectFeatures) { featuresUsed = initializeModels(); @@ -132,9 +134,8 @@ namespace bayesnet { // Variables to control the accuracy finish condition double priorAccuracy = 0.0; double delta = 1.0; - double threshold = 1e-4; - int count = 0; // number of times the accuracy is lower than the threshold - fitted = true; // to enable predict + double convergence_threshold = 1e-4; + int count = 0; // number of times the accuracy improvement (delta) is lower than the convergence_threshold // Step 0: Set the finish condition // if not repeatSparent a finish condition is run out of features // n_models == maxModels @@ -191,10 +192,12 @@ namespace bayesnet { } else { delta = accuracy - priorAccuracy; } - if (delta < threshold) { + if (delta < convergence_threshold) { count++; } + priorAccuracy = accuracy; } 
+ // epsilon_t > 0.5 => invert the weights policy (plot ln(wt)) exitCondition = n_models >= maxModels && repeatSparent || epsilon_t > 0.5 || count > tolerance; } if (featuresUsed.size() != features.size()) { @@ -202,6 +205,7 @@ namespace bayesnet { status = WARNING; } notes.push_back("Number of models: " + std::to_string(n_models)); + fitted = true; } std::vector BoostAODE::graph(const std::string& title) const {