Replace constant strings in BoostAODE

This commit is contained in:
Ricardo Montañana Gómez 2024-03-05 11:05:11 +01:00
parent 78d7ea7c4d
commit 093c197f0a
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE

View File

@ -8,6 +8,16 @@
#include "folding.hpp"
namespace bayesnet {
struct {
std::string CFS = "CFS";
std::string FCBF = "FCBF";
std::string IWSS = "IWSS";
}SelectFeatures;
struct {
std::string ASC = "asc";
std::string DESC = "desc";
std::string RAND = "rand";
}Orders;
BoostAODE::BoostAODE(bool predict_voting) : Ensemble(predict_voting)
{
validHyperparameters = {
@ -61,10 +71,10 @@ namespace bayesnet {
hyperparameters.erase("maxModels");
}
if (hyperparameters.contains("order")) {
std::vector<std::string> algos = { "asc", "desc", "rand" };
std::vector<std::string> algos = { Orders.ASC, Orders.DESC, Orders.RAND };
order_algorithm = hyperparameters["order"];
if (std::find(algos.begin(), algos.end(), order_algorithm) == algos.end()) {
throw std::invalid_argument("Invalid order algorithm, valid values [asc, desc, rand]");
throw std::invalid_argument("Invalid order algorithm, valid values [" + Orders.ASC + ", " + Orders.DESC + ", " + Orders.RAND + "]");
}
hyperparameters.erase("order");
}
@ -90,11 +100,11 @@ namespace bayesnet {
}
if (hyperparameters.contains("select_features")) {
auto selectedAlgorithm = hyperparameters["select_features"];
std::vector<std::string> algos = { "IWSS", "FCBF", "CFS" };
std::vector<std::string> algos = { SelectFeatures.IWSS, SelectFeatures.CFS, SelectFeatures.CFS };
selectFeatures = true;
select_features_algorithm = selectedAlgorithm;
if (std::find(algos.begin(), algos.end(), selectedAlgorithm) == algos.end()) {
throw std::invalid_argument("Invalid selectFeatures value, valid values [IWSS, FCBF, CFS]");
throw std::invalid_argument("Invalid selectFeatures value, valid values [" + SelectFeatures.IWSS + ", " + SelectFeatures.CFS + ", " + SelectFeatures.FCBF + "]");
}
hyperparameters.erase("select_features");
}
@ -107,16 +117,16 @@ namespace bayesnet {
std::unordered_set<int> featuresUsed;
torch::Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
int maxFeatures = 0;
if (select_features_algorithm == "CFS") {
if (select_features_algorithm == SelectFeatures.CFS) {
featureSelector = new CFS(dataset, features, className, maxFeatures, states.at(className).size(), weights_);
} else if (select_features_algorithm == "IWSS") {
} else if (select_features_algorithm == SelectFeatures.IWSS) {
if (threshold < 0 || threshold >0.5) {
throw std::invalid_argument("Invalid threshold value for IWSS [0, 0.5]");
throw std::invalid_argument("Invalid threshold value for " + SelectFeatures.IWSS + " [0, 0.5]");
}
featureSelector = new IWSS(dataset, features, className, maxFeatures, states.at(className).size(), weights_, threshold);
} else if (select_features_algorithm == "FCBF") {
} else if (select_features_algorithm == SelectFeatures.FCBF) {
if (threshold < 1e-7 || threshold > 1) {
throw std::invalid_argument("Invalid threshold value [1e-7, 1]");
throw std::invalid_argument("Invalid threshold value for " + SelectFeatures.FCBF + " [1e-7, 1]");
}
featureSelector = new FCBF(dataset, features, className, maxFeatures, states.at(className).size(), weights_, threshold);
}
@ -174,12 +184,12 @@ namespace bayesnet {
// n_models == maxModels
// epsilon sub t > 0.5 => inverse the weights policy
// validation error is not decreasing
bool ascending = order_algorithm == "asc";
bool ascending = order_algorithm == Orders.ASC;
std::mt19937 g{ 173 };
while (!exitCondition) {
// Step 1: Build ranking with mutual information
auto featureSelection = metrics.SelectKBestWeighted(weights_, ascending, n); // Get all the features sorted
if (order_algorithm == "rand") {
if (order_algorithm == Orders.RAND) {
std::shuffle(featureSelection.begin(), featureSelection.end(), g);
}
auto feature = featureSelection[0];