From fc3d63b7dbe48ce709ebde15b3de24c3a82a408f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?=
Date: Mon, 26 Feb 2024 17:07:57 +0100
Subject: [PATCH] change boostaode ascending hyperparameter to order
 {asc,desc,rand}

---
 CHANGELOG.md             |  6 ++++++
 src/BoostAODE.cc         | 31 ++++++++++++++++++++-----------
 src/BoostAODE.h          |  4 ++--
 tests/TestBayesModels.cc | 37 +++++++++++++++++++++++++++++++------
 4 files changed, 59 insertions(+), 19 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index eda5ea4..8bea3e3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [Unreleased]
+
+### Changed
+
+- Changed the _ascending_ hyperparameter to _order_, with possible values _{"asc", "desc", "rand"}_
+
 ## [1.0.3]
 
 ### Added
diff --git a/src/BoostAODE.cc b/src/BoostAODE.cc
index dc7edb5..9f11f02 100644
--- a/src/BoostAODE.cc
+++ b/src/BoostAODE.cc
@@ -10,7 +10,7 @@ namespace bayesnet {
     BoostAODE::BoostAODE(bool predict_voting) : Ensemble(predict_voting)
     {
-        validHyperparameters = { "repeatSparent", "maxModels", "ascending", "convergence", "threshold", "select_features", "tolerance", "predict_voting" };
+        validHyperparameters = { "repeatSparent", "maxModels", "order", "convergence", "threshold", "select_features", "tolerance", "predict_voting" };
     }
     void BoostAODE::buildModel(const torch::Tensor& weights)
@@ -57,9 +57,13 @@ namespace bayesnet {
             maxModels = hyperparameters["maxModels"];
             hyperparameters.erase("maxModels");
         }
-        if (hyperparameters.contains("ascending")) {
-            ascending = hyperparameters["ascending"];
-            hyperparameters.erase("ascending");
+        if (hyperparameters.contains("order")) {
+            std::vector<std::string> algos = { "asc", "desc", "rand" };
+            order_algorithm = hyperparameters["order"];
+            if (std::find(algos.begin(), algos.end(), order_algorithm) == algos.end()) {
+                throw std::invalid_argument("Invalid order algorithm, valid values [asc, desc, rand]");
+            }
+            hyperparameters.erase("order");
         }
         if (hyperparameters.contains("convergence")) {
             convergence = hyperparameters["convergence"];
@@ -81,9 +85,9 @@ namespace bayesnet {
             auto selectedAlgorithm = hyperparameters["select_features"];
             std::vector<std::string> algos = { "IWSS", "FCBF", "CFS" };
             selectFeatures = true;
-            algorithm = selectedAlgorithm;
+            select_features_algorithm = selectedAlgorithm;
             if (std::find(algos.begin(), algos.end(), selectedAlgorithm) == algos.end()) {
-                throw std::invalid_argument("Invalid selectFeatures value [IWSS, FCBF, CFS]");
+                throw std::invalid_argument("Invalid selectFeatures value, valid values [IWSS, FCBF, CFS]");
             }
             hyperparameters.erase("select_features");
         }
@@ -96,14 +100,14 @@ namespace bayesnet {
         std::unordered_set<int> featuresUsed;
         torch::Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
         int maxFeatures = 0;
-        if (algorithm == "CFS") {
+        if (select_features_algorithm == "CFS") {
             featureSelector = new CFS(dataset, features, className, maxFeatures, states.at(className).size(), weights_);
-        } else if (algorithm == "IWSS") {
+        } else if (select_features_algorithm == "IWSS") {
             if (threshold < 0 || threshold >0.5) {
                 throw std::invalid_argument("Invalid threshold value for IWSS [0, 0.5]");
             }
             featureSelector = new IWSS(dataset, features, className, maxFeatures, states.at(className).size(), weights_, threshold);
-        } else if (algorithm == "FCBF") {
+        } else if (select_features_algorithm == "FCBF") {
             if (threshold < 1e-7 || threshold > 1) {
                 throw std::invalid_argument("Invalid threshold value [1e-7, 1]");
             }
@@ -120,7 +124,7 @@ namespace bayesnet {
             significanceModels.push_back(1.0);
             n_models++;
         }
-        notes.push_back("Used features in initialization: " + std::to_string(featuresUsed.size()) + " of " + std::to_string(features.size()) + " with " + algorithm);
+        notes.push_back("Used features in initialization: " + std::to_string(featuresUsed.size()) + " of " + std::to_string(features.size()) + " with " + select_features_algorithm);
         delete featureSelector;
         return featuresUsed;
     }
@@ -150,10 +154,14 @@ namespace bayesnet {
         // n_models == maxModels
         // epsilon sub t > 0.5 => inverse the weights policy
         // validation error is not decreasing
+        bool ascending = order_algorithm == "asc";
+        std::mt19937 g{ 173 };
         while (!exitCondition) {
             // Step 1: Build ranking with mutual information
             auto featureSelection = metrics.SelectKBestWeighted(weights_, ascending, n); // Get all the features sorted
-            std::unique_ptr<SPODE> model;
+            if (order_algorithm == "rand") {
+                std::shuffle(featureSelection.begin(), featureSelection.end(), g);
+            }
             auto feature = featureSelection[0];
             if (!repeatSparent || featuresUsed.size() < featureSelection.size()) {
                 bool used = true;
@@ -170,6 +178,7 @@ namespace bayesnet {
                     continue;
                 }
             }
+            std::unique_ptr<SPODE> model;
             model = std::make_unique<SPODE>(feature);
             model->fit(dataset, features, className, states, weights_);
             auto ypred = model->predict(X_train);
diff --git a/src/BoostAODE.h b/src/BoostAODE.h
index 551f4be..7119194 100644
--- a/src/BoostAODE.h
+++ b/src/BoostAODE.h
@@ -22,10 +22,10 @@ namespace bayesnet {
         bool repeatSparent = false; // if true, a feature can be selected more than once
         int maxModels = 0;
         int tolerance = 0;
-        bool ascending = false; //Process KBest features ascending or descending order
+        std::string order_algorithm; // order to process the KBest features asc, desc, rand
         bool convergence = false; //if true, stop when the model does not improve
         bool selectFeatures = false; // if true, use feature selection
-        std::string algorithm = ""; // Selected feature selection algorithm
+        std::string select_features_algorithm = ""; // Selected feature selection algorithm
         FeatureSelect* featureSelector = nullptr;
         double threshold = -1;
     };
diff --git a/tests/TestBayesModels.cc b/tests/TestBayesModels.cc
index eb641fa..0234747 100644
--- a/tests/TestBayesModels.cc
+++ b/tests/TestBayesModels.cc
@@ -17,7 +17,7 @@
 const std::string ACTUAL_VERSION = "1.0.3";
 TEST_CASE("Test Bayesian Classifiers score & version", "[BayesNet]")
 {
-    map<pair<std::string, std::string>, float> scores = {
+    map<pair<std::string, std::string>, float> scores{
         // Diabetes
         {{"diabetes", "AODE"}, 0.811198}, {{"diabetes", "KDB"}, 0.852865}, {{"diabetes", "SPODE"}, 0.802083}, {{"diabetes", "TAN"}, 0.821615},
         {{"diabetes", "AODELd"}, 0.8138f}, {{"diabetes", "KDBLd"}, 0.80208f}, {{"diabetes", "SPODELd"}, 0.78646f}, {{"diabetes", "TANLd"}, 0.8099f}, {{"diabetes", "BoostAODE"}, 0.83984f},
@@ -31,7 +31,7 @@ TEST_CASE("Test Bayesian Classifiers score & version", "[BayesNet]")
         {{"iris", "AODE"}, 0.973333}, {{"iris", "KDB"}, 0.973333}, {{"iris", "SPODE"}, 0.973333}, {{"iris", "TAN"}, 0.973333},
         {{"iris", "AODELd"}, 0.973333}, {{"iris", "KDBLd"}, 0.973333}, {{"iris", "SPODELd"}, 0.96f}, {{"iris", "TANLd"}, 0.97333f}, {{"iris", "BoostAODE"}, 0.98f}
     };
-    std::map<std::string, bayesnet::BaseClassifier*> models = {
+    std::map<std::string, bayesnet::BaseClassifier*> models{
         {"AODE", new bayesnet::AODE()}, {"AODELd", new bayesnet::AODELd()}, {"BoostAODE", new bayesnet::BoostAODE()},
         {"KDB", new bayesnet::KDB(2)}, {"KDBLd", new bayesnet::KDBLd(2)},
@@ -104,7 +104,7 @@ TEST_CASE("BoostAODE test used features in train note and score", "[BayesNet]")
     auto raw = RawDatasets("diabetes", true);
     auto clf = bayesnet::BoostAODE(true);
     clf.setHyperparameters({
-        {"ascending",true},
+        {"order", "asc"},
         {"convergence", true},
         {"repeatSparent",true},
         {"select_features","CFS"},
@@ -168,8 +168,8 @@ TEST_CASE("Model predict_proba", "[BayesNet]")
         {0, 1, 0},
         {0, 1, 0}
         });
-    std::map<std::string, std::vector<std::vector<double>>> res_prob = { {"TAN", res_prob_tan}, {"SPODE", res_prob_spode} , {"BoostAODEproba", res_prob_baode }, {"BoostAODEvoting", res_prob_voting } };
-    std::map<std::string, bayesnet::BaseClassifier*> models = { {"TAN", new bayesnet::TAN()}, {"SPODE", new bayesnet::SPODE(0)}, {"BoostAODEproba", new bayesnet::BoostAODE(false)}, {"BoostAODEvoting", new bayesnet::BoostAODE(true)} };
+    std::map<std::string, std::vector<std::vector<double>>> res_prob{ {"TAN", res_prob_tan}, {"SPODE", res_prob_spode} , {"BoostAODEproba", res_prob_baode }, {"BoostAODEvoting", res_prob_voting } };
+    std::map<std::string, bayesnet::BaseClassifier*> models{ {"TAN", new bayesnet::TAN()}, {"SPODE", new bayesnet::SPODE(0)}, {"BoostAODEproba", new bayesnet::BoostAODE(false)}, {"BoostAODEvoting", new bayesnet::BoostAODE(true)} };
     int init_index = 78;
     auto raw = RawDatasets("iris", true);
@@ -178,9 +178,9 @@ TEST_CASE("Model predict_proba", "[BayesNet]")
         auto clf = models[model];
         clf->fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
         auto y_pred_proba = clf->predict_proba(raw.Xv);
+        auto yt_pred_proba = clf->predict_proba(raw.Xt);
         auto y_pred = clf->predict(raw.Xv);
         auto yt_pred = clf->predict(raw.Xt);
-        auto yt_pred_proba = clf->predict_proba(raw.Xt);
         REQUIRE(y_pred.size() == yt_pred.size(0));
         REQUIRE(y_pred.size() == y_pred_proba.size());
         REQUIRE(y_pred.size() == yt_pred_proba.size(0));
@@ -193,6 +193,9 @@ TEST_CASE("Model predict_proba", "[BayesNet]")
             REQUIRE(predictedClass == y_pred[i]);
             // Check predict is coherent with predict_proba
             REQUIRE(yt_pred_proba[i].argmax().item<int>() == y_pred[i]);
+            for (int j = 0; j < yt_pred_proba.size(1); j++) {
+                REQUIRE(yt_pred_proba[i][j].item<double>() == Catch::Approx(y_pred_proba[i][j]).epsilon(raw.epsilon));
+            }
         }
         // Check predict_proba values for vectors and tensors
         for (int i = 0; i < res_prob.size(); i++) {
@@ -222,3 +225,25 @@ TEST_CASE("BoostAODE voting-proba", "[BayesNet]")
     REQUIRE(pred_voting[83][2] == Catch::Approx(0.552091).epsilon(raw.epsilon));
     REQUIRE(pred_proba[83][2] == Catch::Approx(0.546017).epsilon(raw.epsilon));
 }
+TEST_CASE("BoostAODE order asc, desc & random", "[BayesNet]")
+{
+
+    auto raw = RawDatasets("glass", true);
+    std::map<std::string, float> scores{
+        {"asc", 0.83178f }, { "desc", 0.84579f }, { "rand", 0.83645f }
+    };
+    for (const std::string& order : { "asc", "desc", "rand" }) {
+        auto clf = bayesnet::BoostAODE();
+        clf.setHyperparameters({
+            {"order", order},
+        });
+        clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
+        auto score = clf.score(raw.Xv, raw.yv);
+        auto scoret = clf.score(raw.Xt, raw.yt);
+        auto score2 = clf.score(raw.Xv, raw.yv);
+        auto scoret2 = clf.score(raw.Xt, raw.yt);
+        INFO("order: " + order);
+        REQUIRE(score == Catch::Approx(scores[order]).epsilon(raw.epsilon));
+        REQUIRE(scoret == Catch::Approx(scores[order]).epsilon(raw.epsilon));
+    }
+}
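
Usage note: the sketch below, which is not part of the patch, shows how the new "order" hyperparameter could be driven from user code, mirroring the calls in tests/TestBayesModels.cc. The setHyperparameters()/fit()/score() calls and the accepted values "asc", "desc" and "rand" come from the patch; the toy dataset, the std::vector-based argument types assumed for fit(), the "BoostAODE.h" include path and the main() wrapper are illustrative assumptions only.

// Minimal usage sketch (assumptions noted above); mirrors the vector-based
// fit()/score() calls used in the test suite of this patch.
#include <iostream>
#include <map>
#include <string>
#include <vector>
#include "BoostAODE.h"  // assumed include path for bayesnet::BoostAODE

int main()
{
    // Tiny, already-discretized dataset: three binary features, six samples,
    // stored feature-by-feature as assumed for the vector overload of fit().
    std::vector<std::vector<int>> X = {
        { 0, 1, 0, 1, 0, 1 },   // feature A
        { 1, 1, 0, 0, 1, 0 },   // feature B
        { 0, 0, 1, 1, 1, 0 }    // feature C
    };
    std::vector<int> y = { 0, 1, 0, 1, 1, 0 };
    std::vector<std::string> features = { "A", "B", "C" };
    std::string className = "class";
    std::map<std::string, std::vector<int>> states = {
        { "A", { 0, 1 } }, { "B", { 0, 1 } }, { "C", { 0, 1 } }, { "class", { 0, 1 } }
    };
    for (const std::string& order : { "asc", "desc", "rand" }) {
        auto clf = bayesnet::BoostAODE();
        clf.setHyperparameters({ { "order", order } }); // replaces the old boolean "ascending"
        clf.fit(X, y, features, className, states);
        std::cout << "order=" << order << " score=" << clf.score(X, y) << std::endl;
    }
    // Any other value for "order" throws std::invalid_argument:
    // "Invalid order algorithm, valid values [asc, desc, rand]"
    return 0;
}

With "rand", the patch shuffles the mutual-information ranking using a std::mt19937 generator seeded with the fixed value 173, so results for the random order remain reproducible across runs of the same build.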