Replace constant strings in BoostAODE

This commit is contained in:
Ricardo Montañana Gómez 2024-03-05 11:05:11 +01:00
parent 78d7ea7c4d
commit 093c197f0a
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE

View File

@ -8,6 +8,16 @@
#include "folding.hpp" #include "folding.hpp"
namespace bayesnet { namespace bayesnet {
struct {
std::string CFS = "CFS";
std::string FCBF = "FCBF";
std::string IWSS = "IWSS";
}SelectFeatures;
struct {
std::string ASC = "asc";
std::string DESC = "desc";
std::string RAND = "rand";
}Orders;
BoostAODE::BoostAODE(bool predict_voting) : Ensemble(predict_voting) BoostAODE::BoostAODE(bool predict_voting) : Ensemble(predict_voting)
{ {
validHyperparameters = { validHyperparameters = {
@ -61,10 +71,10 @@ namespace bayesnet {
hyperparameters.erase("maxModels"); hyperparameters.erase("maxModels");
} }
if (hyperparameters.contains("order")) { if (hyperparameters.contains("order")) {
std::vector<std::string> algos = { "asc", "desc", "rand" }; std::vector<std::string> algos = { Orders.ASC, Orders.DESC, Orders.RAND };
order_algorithm = hyperparameters["order"]; order_algorithm = hyperparameters["order"];
if (std::find(algos.begin(), algos.end(), order_algorithm) == algos.end()) { if (std::find(algos.begin(), algos.end(), order_algorithm) == algos.end()) {
throw std::invalid_argument("Invalid order algorithm, valid values [asc, desc, rand]"); throw std::invalid_argument("Invalid order algorithm, valid values [" + Orders.ASC + ", " + Orders.DESC + ", " + Orders.RAND + "]");
} }
hyperparameters.erase("order"); hyperparameters.erase("order");
} }
@ -90,11 +100,11 @@ namespace bayesnet {
} }
if (hyperparameters.contains("select_features")) { if (hyperparameters.contains("select_features")) {
auto selectedAlgorithm = hyperparameters["select_features"]; auto selectedAlgorithm = hyperparameters["select_features"];
std::vector<std::string> algos = { "IWSS", "FCBF", "CFS" }; std::vector<std::string> algos = { SelectFeatures.IWSS, SelectFeatures.CFS, SelectFeatures.CFS };
selectFeatures = true; selectFeatures = true;
select_features_algorithm = selectedAlgorithm; select_features_algorithm = selectedAlgorithm;
if (std::find(algos.begin(), algos.end(), selectedAlgorithm) == algos.end()) { if (std::find(algos.begin(), algos.end(), selectedAlgorithm) == algos.end()) {
throw std::invalid_argument("Invalid selectFeatures value, valid values [IWSS, FCBF, CFS]"); throw std::invalid_argument("Invalid selectFeatures value, valid values [" + SelectFeatures.IWSS + ", " + SelectFeatures.CFS + ", " + SelectFeatures.FCBF + "]");
} }
hyperparameters.erase("select_features"); hyperparameters.erase("select_features");
} }
@ -107,16 +117,16 @@ namespace bayesnet {
std::unordered_set<int> featuresUsed; std::unordered_set<int> featuresUsed;
torch::Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64); torch::Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
int maxFeatures = 0; int maxFeatures = 0;
if (select_features_algorithm == "CFS") { if (select_features_algorithm == SelectFeatures.CFS) {
featureSelector = new CFS(dataset, features, className, maxFeatures, states.at(className).size(), weights_); featureSelector = new CFS(dataset, features, className, maxFeatures, states.at(className).size(), weights_);
} else if (select_features_algorithm == "IWSS") { } else if (select_features_algorithm == SelectFeatures.IWSS) {
if (threshold < 0 || threshold >0.5) { if (threshold < 0 || threshold >0.5) {
throw std::invalid_argument("Invalid threshold value for IWSS [0, 0.5]"); throw std::invalid_argument("Invalid threshold value for " + SelectFeatures.IWSS + " [0, 0.5]");
} }
featureSelector = new IWSS(dataset, features, className, maxFeatures, states.at(className).size(), weights_, threshold); featureSelector = new IWSS(dataset, features, className, maxFeatures, states.at(className).size(), weights_, threshold);
} else if (select_features_algorithm == "FCBF") { } else if (select_features_algorithm == SelectFeatures.FCBF) {
if (threshold < 1e-7 || threshold > 1) { if (threshold < 1e-7 || threshold > 1) {
throw std::invalid_argument("Invalid threshold value [1e-7, 1]"); throw std::invalid_argument("Invalid threshold value for " + SelectFeatures.FCBF + " [1e-7, 1]");
} }
featureSelector = new FCBF(dataset, features, className, maxFeatures, states.at(className).size(), weights_, threshold); featureSelector = new FCBF(dataset, features, className, maxFeatures, states.at(className).size(), weights_, threshold);
} }
@ -174,12 +184,12 @@ namespace bayesnet {
// n_models == maxModels // n_models == maxModels
// epsilon sub t > 0.5 => inverse the weights policy // epsilon sub t > 0.5 => inverse the weights policy
// validation error is not decreasing // validation error is not decreasing
bool ascending = order_algorithm == "asc"; bool ascending = order_algorithm == Orders.ASC;
std::mt19937 g{ 173 }; std::mt19937 g{ 173 };
while (!exitCondition) { while (!exitCondition) {
// Step 1: Build ranking with mutual information // Step 1: Build ranking with mutual information
auto featureSelection = metrics.SelectKBestWeighted(weights_, ascending, n); // Get all the features sorted auto featureSelection = metrics.SelectKBestWeighted(weights_, ascending, n); // Get all the features sorted
if (order_algorithm == "rand") { if (order_algorithm == Orders.RAND) {
std::shuffle(featureSelection.begin(), featureSelection.end(), g); std::shuffle(featureSelection.begin(), featureSelection.end(), g);
} }
auto feature = featureSelection[0]; auto feature = featureSelection[0];