From a63a35df3f4b71b4a1a5089ab664998066b560a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Tue, 20 Feb 2024 10:11:22 +0100 Subject: [PATCH] Fix epsilont early stopping in BoostAODE --- .vscode/launch.json | 6 +++--- CHANGELOG.md | 9 +++++++++ CMakeLists.txt | 2 +- lib/catch2 | 2 +- src/BayesNet/BoostAODE.cc | 13 +++++++++---- tests/TestBayesModels.cc | 5 +++-- 6 files changed, 26 insertions(+), 11 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 1e30c2d..a384091 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -106,12 +106,12 @@ "type": "lldb", "request": "launch", "name": "test", - "program": "${workspaceFolder}/build_debug/tests/unit_tests", + "program": "${workspaceFolder}/build_debug/tests/unit_tests_bayesnet", "args": [ - "-c=\"Metrics Test\"", + //"-c=\"Metrics Test\"", // "-s", ], - "cwd": "${workspaceFolder}/build/tests", + "cwd": "${workspaceFolder}/build_debug/tests", }, { "name": "Build & debug active file", diff --git a/CHANGELOG.md b/CHANGELOG.md index cc01c84..3fa6c5a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,15 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +## [1.0.2] - 2024-02-20 + +### Fixed + +- Fix bug in BoostAODE: do not include the model if epsilon sub t is greater than 0.5 +- Fix bug in BoostAODE: compare accuracy with previous accuracy instead of the first of the ensemble if convergence true + ## [1.0.1] - 2024-02-12 ### Added diff --git a/CMakeLists.txt b/CMakeLists.txt index 8cc0b8a..072c729 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.20) project(BayesNet - VERSION 1.0.1 + VERSION 1.0.2 DESCRIPTION "Bayesian Network and basic classifiers Library." HOMEPAGE_URL "https://github.com/rmontanana/bayesnet" LANGUAGES CXX diff --git a/lib/catch2 b/lib/catch2 index 766541d..863c662 160000 --- a/lib/catch2 +++ b/lib/catch2 @@ -1 +1 @@ -Subproject commit 766541d12d64845f5232a1ce4e34a85e83506b09 +Subproject commit 863c662c0eff026300f4d729a7054e90d6d12cdd diff --git a/src/BayesNet/BoostAODE.cc b/src/BayesNet/BoostAODE.cc index 3b1d0f1..959cada 100644 --- a/src/BayesNet/BoostAODE.cc +++ b/src/BayesNet/BoostAODE.cc @@ -121,6 +121,7 @@ namespace bayesnet { } void BoostAODE::trainModel(const torch::Tensor& weights) { + fitted = true; // Algorithm based on the adaboost algorithm for classification // as explained in Ensemble methods (Zhi-Hua Zhou, 2012) std::unordered_set featuresUsed; @@ -161,7 +162,6 @@ namespace bayesnet { continue; } } - featuresUsed.insert(feature); model = std::make_unique(feature); model->fit(dataset, features, className, states, weights_); auto ypred = model->predict(X_train); @@ -170,6 +170,12 @@ namespace bayesnet { auto mask_right = ypred == y_train; auto masked_weights = weights_ * mask_wrong.to(weights_.dtype()); double epsilon_t = masked_weights.sum().item(); + if (epsilon_t > 0.5) { + // Inverse the weights policy (plot ln(wt)) + // "In each round of AdaBoost, there is a sanity check to ensure that the current base + // learner is better than random guess" (Zhi-Hua Zhou, 2012) + break; + } double wt = (1 - epsilon_t) / epsilon_t; double alpha_t = epsilon_t == 0 ? 1 : 0.5 * log(wt); // Step 3.2: Update weights for next classifier @@ -181,6 +187,7 @@ namespace bayesnet { double totalWeights = torch::sum(weights_).item(); weights_ = weights_ / totalWeights; // Step 3.4: Store classifier and its accuracy to weigh its future vote + featuresUsed.insert(feature); models.push_back(std::move(model)); significanceModels.push_back(alpha_t); n_models++; @@ -197,15 +204,13 @@ namespace bayesnet { } priorAccuracy = accuracy; } - // epsilon_t > 0.5 => inverse the weights policy (plot ln(wt)) - exitCondition = n_models >= maxModels && repeatSparent || epsilon_t > 0.5 || count > tolerance; + exitCondition = n_models >= maxModels && repeatSparent || count > tolerance; } if (featuresUsed.size() != features.size()) { notes.push_back("Used features in train: " + std::to_string(featuresUsed.size()) + " of " + std::to_string(features.size())); status = WARNING; } notes.push_back("Number of models: " + std::to_string(n_models)); - fitted = true; } std::vector BoostAODE::graph(const std::string& title) const { diff --git a/tests/TestBayesModels.cc b/tests/TestBayesModels.cc index 3e97aab..bed1783 100644 --- a/tests/TestBayesModels.cc +++ b/tests/TestBayesModels.cc @@ -19,7 +19,7 @@ TEST_CASE("Library check version", "[BayesNet]") { auto clf = bayesnet::KDB(2); - REQUIRE(clf.getVersion() == "1.0.1"); + REQUIRE(clf.getVersion() == "1.0.2"); } TEST_CASE("Test Bayesian Classifiers score", "[BayesNet]") { @@ -164,7 +164,8 @@ TEST_CASE("BoostAODE test used features in train note", "[BayesNet]") {"ascending",true}, {"convergence", true}, {"repeatSparent",true}, - {"select_features","CFS"} + {"select_features","CFS"}, + {"tolerance", 3} }); clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv); REQUIRE(clf.getNumberOfNodes() == 72);