Add comments to BoostAODE algorithm
This commit is contained in:
BIN
docs/BoostAODE.docx
Normal file
BIN
docs/BoostAODE.docx
Normal file
Binary file not shown.
1
lib/argparse
Submodule
1
lib/argparse
Submodule
Submodule lib/argparse added at 69dabd88a8
Submodule lib/catch2 updated: 863c662c0e...766541d12d
1
lib/libxlsxwriter
Submodule
1
lib/libxlsxwriter
Submodule
Submodule lib/libxlsxwriter added at 29355a0887
@@ -121,6 +121,8 @@ namespace bayesnet {
|
|||||||
}
|
}
|
||||||
void BoostAODE::trainModel(const torch::Tensor& weights)
|
void BoostAODE::trainModel(const torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
|
// Algorithm based on the adaboost algorithm for classification
|
||||||
|
// as explained in Ensemble methods (Zhi-Hua Zhou, 2012)
|
||||||
std::unordered_set<int> featuresUsed;
|
std::unordered_set<int> featuresUsed;
|
||||||
if (selectFeatures) {
|
if (selectFeatures) {
|
||||||
featuresUsed = initializeModels();
|
featuresUsed = initializeModels();
|
||||||
@@ -132,9 +134,8 @@ namespace bayesnet {
|
|||||||
// Variables to control the accuracy finish condition
|
// Variables to control the accuracy finish condition
|
||||||
double priorAccuracy = 0.0;
|
double priorAccuracy = 0.0;
|
||||||
double delta = 1.0;
|
double delta = 1.0;
|
||||||
double threshold = 1e-4;
|
double convergence_threshold = 1e-4;
|
||||||
int count = 0; // number of times the accuracy is lower than the threshold
|
int count = 0; // number of times the accuracy is lower than the convergence_threshold
|
||||||
fitted = true; // to enable predict
|
|
||||||
// Step 0: Set the finish condition
|
// Step 0: Set the finish condition
|
||||||
// if not repeatSparent a finish condition is run out of features
|
// if not repeatSparent a finish condition is run out of features
|
||||||
// n_models == maxModels
|
// n_models == maxModels
|
||||||
@@ -191,10 +192,12 @@ namespace bayesnet {
|
|||||||
} else {
|
} else {
|
||||||
delta = accuracy - priorAccuracy;
|
delta = accuracy - priorAccuracy;
|
||||||
}
|
}
|
||||||
if (delta < threshold) {
|
if (delta < convergence_threshold) {
|
||||||
count++;
|
count++;
|
||||||
}
|
}
|
||||||
|
priorAccuracy = accuracy;
|
||||||
}
|
}
|
||||||
|
// epsilon_t > 0.5 => inverse the weights policy (plot ln(wt))
|
||||||
exitCondition = n_models >= maxModels && repeatSparent || epsilon_t > 0.5 || count > tolerance;
|
exitCondition = n_models >= maxModels && repeatSparent || epsilon_t > 0.5 || count > tolerance;
|
||||||
}
|
}
|
||||||
if (featuresUsed.size() != features.size()) {
|
if (featuresUsed.size() != features.size()) {
|
||||||
@@ -202,6 +205,7 @@ namespace bayesnet {
|
|||||||
status = WARNING;
|
status = WARNING;
|
||||||
}
|
}
|
||||||
notes.push_back("Number of models: " + std::to_string(n_models));
|
notes.push_back("Number of models: " + std::to_string(n_models));
|
||||||
|
fitted = true;
|
||||||
}
|
}
|
||||||
std::vector<std::string> BoostAODE::graph(const std::string& title) const
|
std::vector<std::string> BoostAODE::graph(const std::string& title) const
|
||||||
{
|
{
|
||||||
|
Reference in New Issue
Block a user