Complete TestXBAODE, fix XBAODE error when convergence is disabled, and reach 99% coverage

This commit is contained in:
2025-03-13 01:28:48 +01:00
parent b1d317d8f4
commit 4ded6f51eb
3 changed files with 308 additions and 348 deletions

View File

@@ -7,7 +7,7 @@
[![Security Rating](https://sonarcloud.io/api/project_badges/measure?project=rmontanana_BayesNet&metric=security_rating)](https://sonarcloud.io/summary/new_code?id=rmontanana_BayesNet) [![Security Rating](https://sonarcloud.io/api/project_badges/measure?project=rmontanana_BayesNet&metric=security_rating)](https://sonarcloud.io/summary/new_code?id=rmontanana_BayesNet)
[![Reliability Rating](https://sonarcloud.io/api/project_badges/measure?project=rmontanana_BayesNet&metric=reliability_rating)](https://sonarcloud.io/summary/new_code?id=rmontanana_BayesNet) [![Reliability Rating](https://sonarcloud.io/api/project_badges/measure?project=rmontanana_BayesNet&metric=reliability_rating)](https://sonarcloud.io/summary/new_code?id=rmontanana_BayesNet)
![Gitea Last Commit](https://img.shields.io/gitea/last-commit/rmontanana/bayesnet?gitea_url=https://gitea.rmontanana.es:3000&logo=gitea) ![Gitea Last Commit](https://img.shields.io/gitea/last-commit/rmontanana/bayesnet?gitea_url=https://gitea.rmontanana.es:3000&logo=gitea)
[![Coverage Badge](https://img.shields.io/badge/Coverage-98,2%25-green)](html/index.html) [![Coverage Badge](https://img.shields.io/badge/Coverage-99,0%25-green)](html/index.html)
[![DOI](https://zenodo.org/badge/667782806.svg)](https://doi.org/10.5281/zenodo.14210344) [![DOI](https://zenodo.org/badge/667782806.svg)](https://doi.org/10.5281/zenodo.14210344)
Bayesian Network Classifiers library Bayesian Network Classifiers library

View File

@@ -12,10 +12,8 @@
namespace bayesnet { namespace bayesnet {
XBAODE::XBAODE() : Boost(false) { XBAODE::XBAODE() : Boost(false) {
validHyperparameters = { validHyperparameters = {"alpha_block", "order", "convergence", "convergence_best", "bisection",
"alpha_block", "order", "convergence", "threshold", "maxTolerance", "predict_voting", "select_features"};
"convergence_best", "bisection", "threshold",
"maxTolerance", "predict_voting", "select_features"};
} }
void XBAODE::add_model(std::unique_ptr<Classifier> model, double significance) { void XBAODE::add_model(std::unique_ptr<Classifier> model, double significance) {
models.push_back(std::move(model)); models.push_back(std::move(model));
@@ -35,18 +33,17 @@ std::vector<int> XBAODE::initializeModels(const Smoothing_t smoothing) {
model->fit(dataset, features, className, states, weights_, smoothing); model->fit(dataset, features, className, states, weights_, smoothing);
add_model(std::move(model), 1.0); add_model(std::move(model), 1.0);
} }
notes.push_back("Used features in initialization: " + notes.push_back("Used features in initialization: " + std::to_string(featuresSelected.size()) + " of " +
std::to_string(featuresSelected.size()) + " of " + std::to_string(features.size()) + " with " + select_features_algorithm);
std::to_string(features.size()) + " with " +
select_features_algorithm);
return featuresSelected; return featuresSelected;
} }
void XBAODE::trainModel(const torch::Tensor &weights, void XBAODE::trainModel(const torch::Tensor &weights, const bayesnet::Smoothing_t smoothing) {
const bayesnet::Smoothing_t smoothing) {
X_train_ = TensorUtils::to_matrix(X_train); X_train_ = TensorUtils::to_matrix(X_train);
y_train_ = TensorUtils::to_vector<int>(y_train); y_train_ = TensorUtils::to_vector<int>(y_train);
if (convergence) {
X_test_ = TensorUtils::to_matrix(X_test); X_test_ = TensorUtils::to_matrix(X_test);
y_test_ = TensorUtils::to_vector<int>(y_test); y_test_ = TensorUtils::to_vector<int>(y_test);
}
fitted = true; fitted = true;
double alpha_t; double alpha_t;
torch::Tensor weights_ = torch::full({m}, 1.0 / m, torch::kFloat64); torch::Tensor weights_ = torch::full({m}, 1.0 / m, torch::kFloat64);
@@ -57,8 +54,7 @@ void XBAODE::trainModel(const torch::Tensor &weights,
featuresUsed = initializeModels(smoothing); featuresUsed = initializeModels(smoothing);
auto ypred = predict(X_train_); auto ypred = predict(X_train_);
auto ypred_t = torch::tensor(ypred); auto ypred_t = torch::tensor(ypred);
std::tie(weights_, alpha_t, finished) = std::tie(weights_, alpha_t, finished) = update_weights(y_train, ypred_t, weights_);
update_weights(y_train, ypred_t, weights_);
// Update significance of the models // Update significance of the models
for (const int &feature : featuresUsed) { for (const int &feature : featuresUsed) {
significanceModels.pop_back(); significanceModels.pop_back();
@@ -72,14 +68,12 @@ void XBAODE::trainModel(const torch::Tensor &weights,
return; return;
} }
} }
int numItemsPack = int numItemsPack = 0; // The counter of the models inserted in the current pack
0; // The counter of the models inserted in the current pack
// Variables to control the accuracy finish condition // Variables to control the accuracy finish condition
double priorAccuracy = 0.0; double priorAccuracy = 0.0;
double improvement = 1.0; double improvement = 1.0;
double convergence_threshold = 1e-4; double convergence_threshold = 1e-4;
int tolerance = int tolerance = 0; // number of times the accuracy is lower than the convergence_threshold
0; // number of times the accuracy is lower than the convergence_threshold
// Step 0: Set the finish condition // Step 0: Set the finish condition
// epsilon sub t > 0.5 => inverse the weights_ policy // epsilon sub t > 0.5 => inverse the weights_ policy
// validation error is not decreasing // validation error is not decreasing
@@ -88,15 +82,15 @@ void XBAODE::trainModel(const torch::Tensor &weights,
std::mt19937 g{173}; std::mt19937 g{173};
while (!finished) { while (!finished) {
// Step 1: Build ranking with mutual information // Step 1: Build ranking with mutual information
auto featureSelection = metrics.SelectKBestWeighted( auto featureSelection = metrics.SelectKBestWeighted(weights_, ascending, n); // Get all the features sorted
weights_, ascending, n); // Get all the features sorted
if (order_algorithm == bayesnet::Orders.RAND) { if (order_algorithm == bayesnet::Orders.RAND) {
std::shuffle(featureSelection.begin(), featureSelection.end(), g); std::shuffle(featureSelection.begin(), featureSelection.end(), g);
} }
// Remove used features // Remove used features
featureSelection.erase( featureSelection.erase(remove_if(featureSelection.begin(), featureSelection.end(),
remove_if(featureSelection.begin(), featureSelection.end(), [&](auto x) { [&](auto x) {
return std::find(featuresUsed.begin(), featuresUsed.end(), x) != featuresUsed.end(); return std::find(featuresUsed.begin(), featuresUsed.end(), x) !=
featuresUsed.end();
}), }),
featureSelection.end()); featureSelection.end());
int k = bisection ? pow(2, tolerance) : 1; int k = bisection ? pow(2, tolerance) : 1;
@@ -124,6 +118,7 @@ void XBAODE::trainModel(const torch::Tensor &weights,
add_model(std::move(model), 1.0); add_model(std::move(model), 1.0);
// Compute the prediction // Compute the prediction
ypred = predict(X_train_); ypred = predict(X_train_);
model = std::move(models.back());
// Remove the model from the ensemble // Remove the model from the ensemble
remove_last_model(); remove_last_model();
} else { } else {
@@ -142,8 +137,7 @@ void XBAODE::trainModel(const torch::Tensor &weights,
} // End of the pack } // End of the pack
if (convergence && !finished) { if (convergence && !finished) {
auto y_val_predict = predict(X_test); auto y_val_predict = predict(X_test);
double accuracy = (y_val_predict == y_test).sum().item<double>() / double accuracy = (y_val_predict == y_test).sum().item<double>() / (double)y_test.size(0);
(double)y_test.size(0);
if (priorAccuracy == 0) { if (priorAccuracy == 0) {
priorAccuracy = accuracy; priorAccuracy = accuracy;
} else { } else {
@@ -171,17 +165,14 @@ void XBAODE::trainModel(const torch::Tensor &weights,
} }
// VLOG_SCOPE_F(1, "tolerance: %d featuresUsed.size: %zu features.size: // VLOG_SCOPE_F(1, "tolerance: %d featuresUsed.size: %zu features.size:
// %zu", tolerance, featuresUsed.size(), features.size()); // %zu", tolerance, featuresUsed.size(), features.size());
finished = finished || tolerance > maxTolerance || finished = finished || tolerance > maxTolerance || featuresUsed.size() == features.size();
featuresUsed.size() == features.size();
} }
if (tolerance > maxTolerance) { if (tolerance > maxTolerance) {
if (numItemsPack < n_models) { if (numItemsPack < n_models) {
notes.push_back("Convergence threshold reached & " + notes.push_back("Convergence threshold reached & " + std::to_string(numItemsPack) + " models eliminated");
std::to_string(numItemsPack) + " models eliminated");
// VLOG_SCOPE_F(4, "Convergence threshold reached & %d models eliminated // VLOG_SCOPE_F(4, "Convergence threshold reached & %d models eliminated
// of %d", numItemsPack, n_models); // of %d", numItemsPack, n_models);
for (int i = featuresUsed.size() - 1; for (int i = featuresUsed.size() - 1; i >= featuresUsed.size() - numItemsPack; --i) {
i >= featuresUsed.size() - numItemsPack; --i) {
remove_last_model(); remove_last_model();
} }
// VLOG_SCOPE_F(4, "*Convergence threshold %d models left & %d features // VLOG_SCOPE_F(4, "*Convergence threshold %d models left & %d features
@@ -193,7 +184,8 @@ void XBAODE::trainModel(const torch::Tensor &weights,
} }
} }
if (featuresUsed.size() != features.size()) { if (featuresUsed.size() != features.size()) {
notes.push_back( "Used features in train: " + std::to_string(featuresUsed.size()) + " of " + std::to_string(features.size())); notes.push_back("Used features in train: " + std::to_string(featuresUsed.size()) + " of " +
std::to_string(features.size()));
status = bayesnet::WARNING; status = bayesnet::WARNING;
} }
notes.push_back("Number of models: " + std::to_string(n_models)); notes.push_back("Number of models: " + std::to_string(n_models));

View File

@@ -4,12 +4,12 @@
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// *************************************************************** // ***************************************************************
#include "TestUtils.h"
#include "bayesnet/ensembles/XBAODE.h"
#include <catch2/catch_approx.hpp> #include <catch2/catch_approx.hpp>
#include <catch2/catch_test_macros.hpp> #include <catch2/catch_test_macros.hpp>
#include <catch2/generators/catch_generators.hpp> #include <catch2/generators/catch_generators.hpp>
#include <catch2/matchers/catch_matchers.hpp> #include <catch2/matchers/catch_matchers.hpp>
#include "TestUtils.h"
#include "bayesnet/ensembles/XBAODE.h"
TEST_CASE("Normal test", "[XBAODE]") { TEST_CASE("Normal test", "[XBAODE]") {
auto raw = RawDatasets("iris", true); auto raw = RawDatasets("iris", true);
@@ -78,171 +78,139 @@ TEST_CASE("Test used features in train note and score", "[XBAODE]") {
REQUIRE(score == Catch::Approx(0.819010437f).epsilon(raw.epsilon)); REQUIRE(score == Catch::Approx(0.819010437f).epsilon(raw.epsilon));
REQUIRE(scoret == Catch::Approx(0.819010437f).epsilon(raw.epsilon)); REQUIRE(scoret == Catch::Approx(0.819010437f).epsilon(raw.epsilon));
} }
// TEST_CASE("Voting vs proba", "[XBAODE]") TEST_CASE("Order asc, desc & random", "[XBAODE]") {
// { auto raw = RawDatasets("glass", true);
// auto raw = RawDatasets("iris", true); std::map<std::string, double> scores{{"asc", 0.83645f}, {"desc", 0.84579f}, {"rand", 0.84112}};
// auto clf = bayesnet::XBAODE(false); for (const std::string &order : {"asc", "desc", "rand"}) {
// clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, auto clf = bayesnet::XBAODE();
// raw.smoothing); auto score_proba = clf.score(raw.Xv, raw.yv); auto clf.setHyperparameters({
// pred_proba = clf.predict_proba(raw.Xv); clf.setHyperparameters({ {"order", order},
// {"predict_voting",true}, {"bisection", false},
// }); {"maxTolerance", 1},
// auto score_voting = clf.score(raw.Xv, raw.yv); {"convergence", false},
// auto pred_voting = clf.predict_proba(raw.Xv); });
// REQUIRE(score_proba == Catch::Approx(0.97333).epsilon(raw.epsilon)); clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing);
// REQUIRE(score_voting == Catch::Approx(0.98).epsilon(raw.epsilon)); auto score = clf.score(raw.Xv, raw.yv);
// REQUIRE(pred_voting[83][2] == Catch::Approx(1.0).epsilon(raw.epsilon)); auto scoret = clf.score(raw.Xt, raw.yt);
// REQUIRE(pred_proba[83][2] == INFO("XBAODE order: " << order);
// Catch::Approx(0.86121525).epsilon(raw.epsilon)); REQUIRE(clf.dump_cpt() REQUIRE(score == Catch::Approx(scores[order]).epsilon(raw.epsilon));
// == ""); REQUIRE(clf.topological_order() == std::vector<std::string>()); REQUIRE(scoret == Catch::Approx(scores[order]).epsilon(raw.epsilon));
// } }
// TEST_CASE("Order asc, desc & random", "[XBAODE]") }
// { TEST_CASE("Oddities", "[XBAODE]") {
// auto raw = RawDatasets("glass", true); auto clf = bayesnet::XBAODE();
// std::map<std::string, double> scores{ auto raw = RawDatasets("iris", true);
// {"asc", 0.83645f }, { "desc", 0.84579f }, { "rand", 0.84112 } auto bad_hyper = nlohmann::json{
// }; {{"order", "duck"}},
// for (const std::string& order : { "asc", "desc", "rand" }) { {{"select_features", "duck"}},
// auto clf = bayesnet::XBAODE(); {{"maxTolerance", 0}},
// clf.setHyperparameters({ {{"maxTolerance", 7}},
// {"order", order}, };
// {"bisection", false}, for (const auto &hyper : bad_hyper.items()) {
// {"maxTolerance", 1}, INFO("XBAODE hyper: " << hyper.value().dump());
// {"convergence", false}, REQUIRE_THROWS_AS(clf.setHyperparameters(hyper.value()), std::invalid_argument);
// }); }
// clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, REQUIRE_THROWS_AS(clf.setHyperparameters({{"maxTolerance", 0}}), std::invalid_argument);
// raw.smoothing); auto score = clf.score(raw.Xv, raw.yv); auto scoret = auto bad_hyper_fit = nlohmann::json{
// clf.score(raw.Xt, raw.yt); INFO("XBAODE order: " << order); {{"select_features", "IWSS"}, {"threshold", -0.01}},
// REQUIRE(score == Catch::Approx(scores[order]).epsilon(raw.epsilon)); {{"select_features", "IWSS"}, {"threshold", 0.51}},
// REQUIRE(scoret == Catch::Approx(scores[order]).epsilon(raw.epsilon)); {{"select_features", "FCBF"}, {"threshold", 1e-8}},
// } {{"select_features", "FCBF"}, {"threshold", 1.01}},
// } };
// TEST_CASE("Oddities", "[XBAODE]") for (const auto &hyper : bad_hyper_fit.items()) {
// { INFO("XBAODE hyper: " << hyper.value().dump());
// auto clf = bayesnet::XBAODE(); clf.setHyperparameters(hyper.value());
// auto raw = RawDatasets("iris", true); REQUIRE_THROWS_AS(clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing),
// auto bad_hyper = nlohmann::json{ std::invalid_argument);
// { { "order", "duck" } }, }
// { { "select_features", "duck" } }, auto bad_hyper_fit2 = nlohmann::json{
// { { "maxTolerance", 0 } }, {{"alpha_block", true}, {"block_update", true}},
// { { "maxTolerance", 7 } }, {{"bisection", false}, {"block_update", true}},
// }; };
// for (const auto& hyper : bad_hyper.items()) { for (const auto &hyper : bad_hyper_fit2.items()) {
// INFO("XBAODE hyper: " << hyper.value().dump()); INFO("XBAODE hyper: " << hyper.value().dump());
// REQUIRE_THROWS_AS(clf.setHyperparameters(hyper.value()), REQUIRE_THROWS_AS(clf.setHyperparameters(hyper.value()), std::invalid_argument);
// std::invalid_argument); }
// } }
// REQUIRE_THROWS_AS(clf.setHyperparameters({ {"maxTolerance", 0 } }), TEST_CASE("Bisection Best", "[XBAODE]") {
// std::invalid_argument); auto bad_hyper_fit = nlohmann::json{ auto clf = bayesnet::XBAODE();
// { { "select_features","IWSS" }, { "threshold", -0.01 } }, auto raw = RawDatasets("kdd_JapaneseVowels", true, 1200, true, false);
// { { "select_features","IWSS" }, { "threshold", 0.51 } }, clf.setHyperparameters({
// { { "select_features","FCBF" }, { "threshold", 1e-8 } }, {"bisection", true},
// { { "select_features","FCBF" }, { "threshold", 1.01 } }, {"maxTolerance", 3},
// }; {"convergence", true},
// for (const auto& hyper : bad_hyper_fit.items()) { {"convergence_best", false},
// INFO("XBAODE hyper: " << hyper.value().dump()); });
// clf.setHyperparameters(hyper.value()); clf.fit(raw.X_train, raw.y_train, raw.features, raw.className, raw.states, raw.smoothing);
// REQUIRE_THROWS_AS(clf.fit(raw.Xv, raw.yv, raw.features, REQUIRE(clf.getNumberOfNodes() == 210);
// raw.className, raw.states, raw.smoothing), std::invalid_argument); REQUIRE(clf.getNumberOfEdges() == 406);
// } REQUIRE(clf.getNotes().size() == 1);
REQUIRE(clf.getNotes().at(0) == "Number of models: 14");
// auto bad_hyper_fit2 = nlohmann::json{ auto score = clf.score(raw.X_test, raw.y_test);
// { { "alpha_block", true }, { "block_update", true } }, auto scoret = clf.score(raw.X_test, raw.y_test);
// { { "bisection", false }, { "block_update", true } }, REQUIRE(score == Catch::Approx(0.991666675f).epsilon(raw.epsilon));
// }; REQUIRE(scoret == Catch::Approx(0.991666675f).epsilon(raw.epsilon));
// for (const auto& hyper : bad_hyper_fit2.items()) { }
// INFO("XBAODE hyper: " << hyper.value().dump()); TEST_CASE("Bisection Best vs Last", "[XBAODE]") {
// REQUIRE_THROWS_AS(clf.setHyperparameters(hyper.value()), auto raw = RawDatasets("kdd_JapaneseVowels", true, 1500, true, false);
// std::invalid_argument); auto clf = bayesnet::XBAODE();
// } auto hyperparameters = nlohmann::json{
// } {"bisection", true},
// TEST_CASE("Bisection Best", "[XBAODE]") {"maxTolerance", 3},
// { {"convergence", true},
// auto clf = bayesnet::XBAODE(); {"convergence_best", true},
// auto raw = RawDatasets("kdd_JapaneseVowels", true, 1200, true, false); };
// clf.setHyperparameters({ clf.setHyperparameters(hyperparameters);
// {"bisection", true}, clf.fit(raw.X_train, raw.y_train, raw.features, raw.className, raw.states, raw.smoothing);
// {"maxTolerance", 3}, auto score_best = clf.score(raw.X_test, raw.y_test);
// {"convergence", true}, REQUIRE(score_best == Catch::Approx(0.973333359f).epsilon(raw.epsilon));
// {"convergence_best", false}, // Now we will set the hyperparameter to use the last accuracy
// }); hyperparameters["convergence_best"] = false;
// clf.fit(raw.X_train, raw.y_train, raw.features, raw.className, clf.setHyperparameters(hyperparameters);
// raw.states, raw.smoothing); REQUIRE(clf.getNumberOfNodes() == 210); clf.fit(raw.X_train, raw.y_train, raw.features, raw.className, raw.states, raw.smoothing);
// REQUIRE(clf.getNumberOfEdges() == 378); auto score_last = clf.score(raw.X_test, raw.y_test);
// REQUIRE(clf.getNotes().size() == 1); REQUIRE(score_last == Catch::Approx(0.976666689f).epsilon(raw.epsilon));
// REQUIRE(clf.getNotes().at(0) == "Number of models: 14"); }
// auto score = clf.score(raw.X_test, raw.y_test); TEST_CASE("Block Update", "[XBAODE]") {
// auto scoret = clf.score(raw.X_test, raw.y_test); auto clf = bayesnet::XBAODE();
// REQUIRE(score == Catch::Approx(0.991666675f).epsilon(raw.epsilon)); auto raw = RawDatasets("mfeat-factors", true, 500);
// REQUIRE(scoret == Catch::Approx(0.991666675f).epsilon(raw.epsilon)); clf.setHyperparameters({
// } {"bisection", true},
// TEST_CASE("Bisection Best vs Last", "[XBAODE]") {"block_update", true},
// { {"maxTolerance", 3},
// auto raw = RawDatasets("kdd_JapaneseVowels", true, 1500, true, false); {"convergence", true},
// auto clf = bayesnet::XBAODE(true); });
// auto hyperparameters = nlohmann::json{ clf.fit(raw.X_train, raw.y_train, raw.features, raw.className, raw.states, raw.smoothing);
// {"bisection", true}, REQUIRE(clf.getNumberOfNodes() == 1085);
// {"maxTolerance", 3}, REQUIRE(clf.getNumberOfEdges() == 2165);
// {"convergence", true}, REQUIRE(clf.getNotes().size() == 3);
// {"convergence_best", true}, REQUIRE(clf.getNotes()[0] == "Convergence threshold reached & 15 models eliminated");
// }; REQUIRE(clf.getNotes()[1] == "Used features in train: 20 of 216");
// clf.setHyperparameters(hyperparameters); REQUIRE(clf.getNotes()[2] == "Number of models: 5");
// clf.fit(raw.X_train, raw.y_train, raw.features, raw.className, auto score = clf.score(raw.X_test, raw.y_test);
// raw.states, raw.smoothing); auto score_best = clf.score(raw.X_test, auto scoret = clf.score(raw.X_test, raw.y_test);
// raw.y_test); REQUIRE(score_best == REQUIRE(score == Catch::Approx(1.0f).epsilon(raw.epsilon));
// Catch::Approx(0.980000019f).epsilon(raw.epsilon)); REQUIRE(scoret == Catch::Approx(1.0f).epsilon(raw.epsilon));
// // Now we will set the hyperparameter to use the last accuracy //
// hyperparameters["convergence_best"] = false; // std::cout << "Number of nodes " << clf.getNumberOfNodes() << std::endl;
// clf.setHyperparameters(hyperparameters); // std::cout << "Number of edges " << clf.getNumberOfEdges() << std::endl;
// clf.fit(raw.X_train, raw.y_train, raw.features, raw.className, // std::cout << "Notes size " << clf.getNotes().size() << std::endl;
// raw.states, raw.smoothing); auto score_last = clf.score(raw.X_test, // for (auto note : clf.getNotes()) {
// raw.y_test); REQUIRE(score_last == // std::cout << note << std::endl;
// Catch::Approx(0.976666689f).epsilon(raw.epsilon)); // }
// } // std::cout << "Score " << score << std::endl;
// TEST_CASE("Block Update", "[XBAODE]") }
// { TEST_CASE("Alphablock", "[XBAODE]") {
// auto clf = bayesnet::XBAODE(); auto clf_alpha = bayesnet::XBAODE();
// auto raw = RawDatasets("mfeat-factors", true, 500); auto clf_no_alpha = bayesnet::XBAODE();
// clf.setHyperparameters({ auto raw = RawDatasets("diabetes", true);
// {"bisection", true}, clf_alpha.setHyperparameters({
// {"block_update", true}, {"alpha_block", true},
// {"maxTolerance", 3}, });
// {"convergence", true}, clf_alpha.fit(raw.X_train, raw.y_train, raw.features, raw.className, raw.states, raw.smoothing);
// }); clf_no_alpha.fit(raw.X_train, raw.y_train, raw.features, raw.className, raw.states, raw.smoothing);
// clf.fit(raw.X_train, raw.y_train, raw.features, raw.className, auto score_alpha = clf_alpha.score(raw.X_test, raw.y_test);
// raw.states, raw.smoothing); REQUIRE(clf.getNumberOfNodes() == 868); auto score_no_alpha = clf_no_alpha.score(raw.X_test, raw.y_test);
// REQUIRE(clf.getNumberOfEdges() == 1724); REQUIRE(score_alpha == Catch::Approx(0.720779f).epsilon(raw.epsilon));
// REQUIRE(clf.getNotes().size() == 3); REQUIRE(score_no_alpha == Catch::Approx(0.733766f).epsilon(raw.epsilon));
// REQUIRE(clf.getNotes()[0] == "Convergence threshold reached & 15 models }
// eliminated"); REQUIRE(clf.getNotes()[1] == "Used features in train: 19 of
// 216"); REQUIRE(clf.getNotes()[2] == "Number of models: 4"); auto score =
// clf.score(raw.X_test, raw.y_test); auto scoret = clf.score(raw.X_test,
// raw.y_test); REQUIRE(score == Catch::Approx(0.99f).epsilon(raw.epsilon));
// REQUIRE(scoret == Catch::Approx(0.99f).epsilon(raw.epsilon));
// //
// // std::cout << "Number of nodes " << clf.getNumberOfNodes() <<
// std::endl;
// // std::cout << "Number of edges " << clf.getNumberOfEdges() <<
// std::endl;
// // std::cout << "Notes size " << clf.getNotes().size() << std::endl;
// // for (auto note : clf.getNotes()) {
// // std::cout << note << std::endl;
// // }
// // std::cout << "Score " << score << std::endl;
// }
// TEST_CASE("Alphablock", "[XBAODE]")
// {
// auto clf_alpha = bayesnet::XBAODE();
// auto clf_no_alpha = bayesnet::XBAODE();
// auto raw = RawDatasets("diabetes", true);
// clf_alpha.setHyperparameters({
// {"alpha_block", true},
// });
// clf_alpha.fit(raw.X_train, raw.y_train, raw.features, raw.className,
// raw.states, raw.smoothing); clf_no_alpha.fit(raw.X_train, raw.y_train,
// raw.features, raw.className, raw.states, raw.smoothing); auto score_alpha
// = clf_alpha.score(raw.X_test, raw.y_test); auto score_no_alpha =
// clf_no_alpha.score(raw.X_test, raw.y_test); REQUIRE(score_alpha ==
// Catch::Approx(0.720779f).epsilon(raw.epsilon)); REQUIRE(score_no_alpha ==
// Catch::Approx(0.733766f).epsilon(raw.epsilon));
// }