Compare commits
2 Commits
Author | SHA1 | Date | |
---|---|---|---|
a63a35df3f
|
|||
c7555dac3f
|
6
.vscode/launch.json
vendored
6
.vscode/launch.json
vendored
@@ -106,12 +106,12 @@
|
|||||||
"type": "lldb",
|
"type": "lldb",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"name": "test",
|
"name": "test",
|
||||||
"program": "${workspaceFolder}/build_debug/tests/unit_tests",
|
"program": "${workspaceFolder}/build_debug/tests/unit_tests_bayesnet",
|
||||||
"args": [
|
"args": [
|
||||||
"-c=\"Metrics Test\"",
|
//"-c=\"Metrics Test\"",
|
||||||
// "-s",
|
// "-s",
|
||||||
],
|
],
|
||||||
"cwd": "${workspaceFolder}/build/tests",
|
"cwd": "${workspaceFolder}/build_debug/tests",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Build & debug active file",
|
"name": "Build & debug active file",
|
||||||
|
@@ -5,6 +5,15 @@ All notable changes to this project will be documented in this file.
|
|||||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
|
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
|
||||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||||
|
|
||||||
|
## [Unreleased]
|
||||||
|
|
||||||
|
## [1.0.2] - 2024-02-20
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- Fix bug in BoostAODE: do not include the model in the ensemble if its weighted error (epsilon_t) is greater than 0.5
|
||||||
|
- Fix bug in BoostAODE: when convergence is enabled, compare the accuracy with that of the previous iteration instead of the first of the ensemble
|
||||||
|
|
||||||
## [1.0.1] - 2024-02-12
|
## [1.0.1] - 2024-02-12
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
@@ -1,7 +1,7 @@
|
|||||||
cmake_minimum_required(VERSION 3.20)
|
cmake_minimum_required(VERSION 3.20)
|
||||||
|
|
||||||
project(BayesNet
|
project(BayesNet
|
||||||
VERSION 1.0.1
|
VERSION 1.0.2
|
||||||
DESCRIPTION "Bayesian Network and basic classifiers Library."
|
DESCRIPTION "Bayesian Network and basic classifiers Library."
|
||||||
HOMEPAGE_URL "https://github.com/rmontanana/bayesnet"
|
HOMEPAGE_URL "https://github.com/rmontanana/bayesnet"
|
||||||
LANGUAGES CXX
|
LANGUAGES CXX
|
||||||
|
BIN
docs/BoostAODE.docx
Normal file
BIN
docs/BoostAODE.docx
Normal file
Binary file not shown.
1
lib/argparse
Submodule
1
lib/argparse
Submodule
Submodule lib/argparse added at 69dabd88a8
1
lib/libxlsxwriter
Submodule
1
lib/libxlsxwriter
Submodule
Submodule lib/libxlsxwriter added at 29355a0887
@@ -121,6 +121,9 @@ namespace bayesnet {
|
|||||||
}
|
}
|
||||||
void BoostAODE::trainModel(const torch::Tensor& weights)
|
void BoostAODE::trainModel(const torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
|
fitted = true;
|
||||||
|
// Algorithm based on the adaboost algorithm for classification
|
||||||
|
// as explained in Ensemble methods (Zhi-Hua Zhou, 2012)
|
||||||
std::unordered_set<int> featuresUsed;
|
std::unordered_set<int> featuresUsed;
|
||||||
if (selectFeatures) {
|
if (selectFeatures) {
|
||||||
featuresUsed = initializeModels();
|
featuresUsed = initializeModels();
|
||||||
@@ -132,9 +135,8 @@ namespace bayesnet {
|
|||||||
// Variables to control the accuracy finish condition
|
// Variables to control the accuracy finish condition
|
||||||
double priorAccuracy = 0.0;
|
double priorAccuracy = 0.0;
|
||||||
double delta = 1.0;
|
double delta = 1.0;
|
||||||
double threshold = 1e-4;
|
double convergence_threshold = 1e-4;
|
||||||
int count = 0; // number of times the accuracy is lower than the threshold
|
int count = 0; // number of times the accuracy is lower than the convergence_threshold
|
||||||
fitted = true; // to enable predict
|
|
||||||
// Step 0: Set the finish condition
|
// Step 0: Set the finish condition
|
||||||
// if not repeatSparent a finish condition is run out of features
|
// if not repeatSparent a finish condition is run out of features
|
||||||
// n_models == maxModels
|
// n_models == maxModels
|
||||||
@@ -160,7 +162,6 @@ namespace bayesnet {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
featuresUsed.insert(feature);
|
|
||||||
model = std::make_unique<SPODE>(feature);
|
model = std::make_unique<SPODE>(feature);
|
||||||
model->fit(dataset, features, className, states, weights_);
|
model->fit(dataset, features, className, states, weights_);
|
||||||
auto ypred = model->predict(X_train);
|
auto ypred = model->predict(X_train);
|
||||||
@@ -169,6 +170,12 @@ namespace bayesnet {
|
|||||||
auto mask_right = ypred == y_train;
|
auto mask_right = ypred == y_train;
|
||||||
auto masked_weights = weights_ * mask_wrong.to(weights_.dtype());
|
auto masked_weights = weights_ * mask_wrong.to(weights_.dtype());
|
||||||
double epsilon_t = masked_weights.sum().item<double>();
|
double epsilon_t = masked_weights.sum().item<double>();
|
||||||
|
if (epsilon_t > 0.5) {
|
||||||
|
// Inverse the weights policy (plot ln(wt))
|
||||||
|
// "In each round of AdaBoost, there is a sanity check to ensure that the current base
|
||||||
|
// learner is better than random guess" (Zhi-Hua Zhou, 2012)
|
||||||
|
break;
|
||||||
|
}
|
||||||
double wt = (1 - epsilon_t) / epsilon_t;
|
double wt = (1 - epsilon_t) / epsilon_t;
|
||||||
double alpha_t = epsilon_t == 0 ? 1 : 0.5 * log(wt);
|
double alpha_t = epsilon_t == 0 ? 1 : 0.5 * log(wt);
|
||||||
// Step 3.2: Update weights for next classifier
|
// Step 3.2: Update weights for next classifier
|
||||||
@@ -180,6 +187,7 @@ namespace bayesnet {
|
|||||||
double totalWeights = torch::sum(weights_).item<double>();
|
double totalWeights = torch::sum(weights_).item<double>();
|
||||||
weights_ = weights_ / totalWeights;
|
weights_ = weights_ / totalWeights;
|
||||||
// Step 3.4: Store classifier and its accuracy to weigh its future vote
|
// Step 3.4: Store classifier and its accuracy to weigh its future vote
|
||||||
|
featuresUsed.insert(feature);
|
||||||
models.push_back(std::move(model));
|
models.push_back(std::move(model));
|
||||||
significanceModels.push_back(alpha_t);
|
significanceModels.push_back(alpha_t);
|
||||||
n_models++;
|
n_models++;
|
||||||
@@ -191,11 +199,12 @@ namespace bayesnet {
|
|||||||
} else {
|
} else {
|
||||||
delta = accuracy - priorAccuracy;
|
delta = accuracy - priorAccuracy;
|
||||||
}
|
}
|
||||||
if (delta < threshold) {
|
if (delta < convergence_threshold) {
|
||||||
count++;
|
count++;
|
||||||
}
|
}
|
||||||
|
priorAccuracy = accuracy;
|
||||||
}
|
}
|
||||||
exitCondition = n_models >= maxModels && repeatSparent || epsilon_t > 0.5 || count > tolerance;
|
exitCondition = n_models >= maxModels && repeatSparent || count > tolerance;
|
||||||
}
|
}
|
||||||
if (featuresUsed.size() != features.size()) {
|
if (featuresUsed.size() != features.size()) {
|
||||||
notes.push_back("Used features in train: " + std::to_string(featuresUsed.size()) + " of " + std::to_string(features.size()));
|
notes.push_back("Used features in train: " + std::to_string(featuresUsed.size()) + " of " + std::to_string(features.size()));
|
||||||
|
@@ -19,7 +19,7 @@
|
|||||||
TEST_CASE("Library check version", "[BayesNet]")
|
TEST_CASE("Library check version", "[BayesNet]")
|
||||||
{
|
{
|
||||||
auto clf = bayesnet::KDB(2);
|
auto clf = bayesnet::KDB(2);
|
||||||
REQUIRE(clf.getVersion() == "1.0.1");
|
REQUIRE(clf.getVersion() == "1.0.2");
|
||||||
}
|
}
|
||||||
TEST_CASE("Test Bayesian Classifiers score", "[BayesNet]")
|
TEST_CASE("Test Bayesian Classifiers score", "[BayesNet]")
|
||||||
{
|
{
|
||||||
@@ -164,7 +164,8 @@ TEST_CASE("BoostAODE test used features in train note", "[BayesNet]")
|
|||||||
{"ascending",true},
|
{"ascending",true},
|
||||||
{"convergence", true},
|
{"convergence", true},
|
||||||
{"repeatSparent",true},
|
{"repeatSparent",true},
|
||||||
{"select_features","CFS"}
|
{"select_features","CFS"},
|
||||||
|
{"tolerance", 3}
|
||||||
});
|
});
|
||||||
clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
|
clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
|
||||||
REQUIRE(clf.getNumberOfNodes() == 72);
|
REQUIRE(clf.getNumberOfNodes() == 72);
|
||||||
|
Reference in New Issue
Block a user