From 093c197f0a92ecb74a53cac46ca6cd18e82d2f2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Tue, 5 Mar 2024 11:05:11 +0100 Subject: [PATCH] Replace constant strings in BoostAODE --- src/ensembles/BoostAODE.cc | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/src/ensembles/BoostAODE.cc b/src/ensembles/BoostAODE.cc index 10eb694..7bc32b8 100644 --- a/src/ensembles/BoostAODE.cc +++ b/src/ensembles/BoostAODE.cc @@ -8,6 +8,16 @@ #include "folding.hpp" namespace bayesnet { + struct { + std::string CFS = "CFS"; + std::string FCBF = "FCBF"; + std::string IWSS = "IWSS"; + }SelectFeatures; + struct { + std::string ASC = "asc"; + std::string DESC = "desc"; + std::string RAND = "rand"; + }Orders; BoostAODE::BoostAODE(bool predict_voting) : Ensemble(predict_voting) { validHyperparameters = { @@ -61,10 +71,10 @@ namespace bayesnet { hyperparameters.erase("maxModels"); } if (hyperparameters.contains("order")) { - std::vector algos = { "asc", "desc", "rand" }; + std::vector algos = { Orders.ASC, Orders.DESC, Orders.RAND }; order_algorithm = hyperparameters["order"]; if (std::find(algos.begin(), algos.end(), order_algorithm) == algos.end()) { - throw std::invalid_argument("Invalid order algorithm, valid values [asc, desc, rand]"); + throw std::invalid_argument("Invalid order algorithm, valid values [" + Orders.ASC + ", " + Orders.DESC + ", " + Orders.RAND + "]"); } hyperparameters.erase("order"); } @@ -90,11 +100,11 @@ namespace bayesnet { } if (hyperparameters.contains("select_features")) { auto selectedAlgorithm = hyperparameters["select_features"]; - std::vector algos = { "IWSS", "FCBF", "CFS" }; + std::vector algos = { SelectFeatures.IWSS, SelectFeatures.CFS, SelectFeatures.CFS }; selectFeatures = true; select_features_algorithm = selectedAlgorithm; if (std::find(algos.begin(), algos.end(), selectedAlgorithm) == algos.end()) { - throw std::invalid_argument("Invalid selectFeatures value, valid values [IWSS, FCBF, CFS]"); + throw std::invalid_argument("Invalid selectFeatures value, valid values [" + SelectFeatures.IWSS + ", " + SelectFeatures.CFS + ", " + SelectFeatures.FCBF + "]"); } hyperparameters.erase("select_features"); } @@ -107,16 +117,16 @@ namespace bayesnet { std::unordered_set featuresUsed; torch::Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64); int maxFeatures = 0; - if (select_features_algorithm == "CFS") { + if (select_features_algorithm == SelectFeatures.CFS) { featureSelector = new CFS(dataset, features, className, maxFeatures, states.at(className).size(), weights_); - } else if (select_features_algorithm == "IWSS") { + } else if (select_features_algorithm == SelectFeatures.IWSS) { if (threshold < 0 || threshold >0.5) { - throw std::invalid_argument("Invalid threshold value for IWSS [0, 0.5]"); + throw std::invalid_argument("Invalid threshold value for " + SelectFeatures.IWSS + " [0, 0.5]"); } featureSelector = new IWSS(dataset, features, className, maxFeatures, states.at(className).size(), weights_, threshold); - } else if (select_features_algorithm == "FCBF") { + } else if (select_features_algorithm == SelectFeatures.FCBF) { if (threshold < 1e-7 || threshold > 1) { - throw std::invalid_argument("Invalid threshold value [1e-7, 1]"); + throw std::invalid_argument("Invalid threshold value for " + SelectFeatures.FCBF + " [1e-7, 1]"); } featureSelector = new FCBF(dataset, features, className, maxFeatures, states.at(className).size(), weights_, threshold); } @@ -174,12 +184,12 @@ namespace bayesnet { // n_models == maxModels // epsilon sub t > 0.5 => inverse the weights policy // validation error is not decreasing - bool ascending = order_algorithm == "asc"; + bool ascending = order_algorithm == Orders.ASC; std::mt19937 g{ 173 }; while (!exitCondition) { // Step 1: Build ranking with mutual information auto featureSelection = metrics.SelectKBestWeighted(weights_, ascending, n); // Get all the features sorted - if (order_algorithm == "rand") { + if (order_algorithm == Orders.RAND) { std::shuffle(featureSelection.begin(), featureSelection.end(), g); } auto feature = featureSelection[0];