Compare commits

...

2 Commits

8 changed files with 33 additions and 12 deletions

6
.vscode/launch.json vendored
View File

@@ -106,12 +106,12 @@
"type": "lldb", "type": "lldb",
"request": "launch", "request": "launch",
"name": "test", "name": "test",
"program": "${workspaceFolder}/build_debug/tests/unit_tests", "program": "${workspaceFolder}/build_debug/tests/unit_tests_bayesnet",
"args": [ "args": [
"-c=\"Metrics Test\"", //"-c=\"Metrics Test\"",
// "-s", // "-s",
], ],
"cwd": "${workspaceFolder}/build/tests", "cwd": "${workspaceFolder}/build_debug/tests",
}, },
{ {
"name": "Build & debug active file", "name": "Build & debug active file",

View File

@@ -5,6 +5,15 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
## [1.0.2] - 2024-02-20
### Fixed
- Fix bug in BoostAODE: do not include the model if `epsilon_t` is greater than 0.5
- Fix bug in BoostAODE: when convergence is true, compare accuracy with the previous accuracy instead of the first accuracy of the ensemble
## [1.0.1] - 2024-02-12 ## [1.0.1] - 2024-02-12
### Added ### Added

View File

@@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.20) cmake_minimum_required(VERSION 3.20)
project(BayesNet project(BayesNet
VERSION 1.0.1 VERSION 1.0.2
DESCRIPTION "Bayesian Network and basic classifiers Library." DESCRIPTION "Bayesian Network and basic classifiers Library."
HOMEPAGE_URL "https://github.com/rmontanana/bayesnet" HOMEPAGE_URL "https://github.com/rmontanana/bayesnet"
LANGUAGES CXX LANGUAGES CXX

BIN
docs/BoostAODE.docx Normal file

Binary file not shown.

1
lib/argparse Submodule

Submodule lib/argparse added at 69dabd88a8

1
lib/libxlsxwriter Submodule

Submodule lib/libxlsxwriter added at 29355a0887

View File

@@ -121,6 +121,9 @@ namespace bayesnet {
} }
void BoostAODE::trainModel(const torch::Tensor& weights) void BoostAODE::trainModel(const torch::Tensor& weights)
{ {
fitted = true;
// Algorithm based on the adaboost algorithm for classification
// as explained in Ensemble methods (Zhi-Hua Zhou, 2012)
std::unordered_set<int> featuresUsed; std::unordered_set<int> featuresUsed;
if (selectFeatures) { if (selectFeatures) {
featuresUsed = initializeModels(); featuresUsed = initializeModels();
@@ -132,9 +135,8 @@ namespace bayesnet {
// Variables to control the accuracy finish condition // Variables to control the accuracy finish condition
double priorAccuracy = 0.0; double priorAccuracy = 0.0;
double delta = 1.0; double delta = 1.0;
double threshold = 1e-4; double convergence_threshold = 1e-4;
int count = 0; // number of times the accuracy is lower than the threshold int count = 0; // number of times the accuracy is lower than the convergence_threshold
fitted = true; // to enable predict
// Step 0: Set the finish condition // Step 0: Set the finish condition
// if not repeatSparent a finish condition is run out of features // if not repeatSparent a finish condition is run out of features
// n_models == maxModels // n_models == maxModels
@@ -160,7 +162,6 @@ namespace bayesnet {
continue; continue;
} }
} }
featuresUsed.insert(feature);
model = std::make_unique<SPODE>(feature); model = std::make_unique<SPODE>(feature);
model->fit(dataset, features, className, states, weights_); model->fit(dataset, features, className, states, weights_);
auto ypred = model->predict(X_train); auto ypred = model->predict(X_train);
@@ -169,6 +170,12 @@ namespace bayesnet {
auto mask_right = ypred == y_train; auto mask_right = ypred == y_train;
auto masked_weights = weights_ * mask_wrong.to(weights_.dtype()); auto masked_weights = weights_ * mask_wrong.to(weights_.dtype());
double epsilon_t = masked_weights.sum().item<double>(); double epsilon_t = masked_weights.sum().item<double>();
if (epsilon_t > 0.5) {
// Inverse the weights policy (plot ln(wt))
// "In each round of AdaBoost, there is a sanity check to ensure that the current base
// learner is better than random guess" (Zhi-Hua Zhou, 2012)
break;
}
double wt = (1 - epsilon_t) / epsilon_t; double wt = (1 - epsilon_t) / epsilon_t;
double alpha_t = epsilon_t == 0 ? 1 : 0.5 * log(wt); double alpha_t = epsilon_t == 0 ? 1 : 0.5 * log(wt);
// Step 3.2: Update weights for next classifier // Step 3.2: Update weights for next classifier
@@ -180,6 +187,7 @@ namespace bayesnet {
double totalWeights = torch::sum(weights_).item<double>(); double totalWeights = torch::sum(weights_).item<double>();
weights_ = weights_ / totalWeights; weights_ = weights_ / totalWeights;
// Step 3.4: Store classifier and its accuracy to weigh its future vote // Step 3.4: Store classifier and its accuracy to weigh its future vote
featuresUsed.insert(feature);
models.push_back(std::move(model)); models.push_back(std::move(model));
significanceModels.push_back(alpha_t); significanceModels.push_back(alpha_t);
n_models++; n_models++;
@@ -191,11 +199,12 @@ namespace bayesnet {
} else { } else {
delta = accuracy - priorAccuracy; delta = accuracy - priorAccuracy;
} }
if (delta < threshold) { if (delta < convergence_threshold) {
count++; count++;
} }
priorAccuracy = accuracy;
} }
exitCondition = n_models >= maxModels && repeatSparent || epsilon_t > 0.5 || count > tolerance; exitCondition = n_models >= maxModels && repeatSparent || count > tolerance;
} }
if (featuresUsed.size() != features.size()) { if (featuresUsed.size() != features.size()) {
notes.push_back("Used features in train: " + std::to_string(featuresUsed.size()) + " of " + std::to_string(features.size())); notes.push_back("Used features in train: " + std::to_string(featuresUsed.size()) + " of " + std::to_string(features.size()));

View File

@@ -19,7 +19,7 @@
TEST_CASE("Library check version", "[BayesNet]") TEST_CASE("Library check version", "[BayesNet]")
{ {
auto clf = bayesnet::KDB(2); auto clf = bayesnet::KDB(2);
REQUIRE(clf.getVersion() == "1.0.1"); REQUIRE(clf.getVersion() == "1.0.2");
} }
TEST_CASE("Test Bayesian Classifiers score", "[BayesNet]") TEST_CASE("Test Bayesian Classifiers score", "[BayesNet]")
{ {
@@ -164,7 +164,8 @@ TEST_CASE("BoostAODE test used features in train note", "[BayesNet]")
{"ascending",true}, {"ascending",true},
{"convergence", true}, {"convergence", true},
{"repeatSparent",true}, {"repeatSparent",true},
{"select_features","CFS"} {"select_features","CFS"},
{"tolerance", 3}
}); });
clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv); clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
REQUIRE(clf.getNumberOfNodes() == 72); REQUIRE(clf.getNumberOfNodes() == 72);