Replace constant strings in BoostAODE
This commit is contained in:
parent
78d7ea7c4d
commit
093c197f0a
@ -8,6 +8,16 @@
|
||||
#include "folding.hpp"
|
||||
|
||||
namespace bayesnet {
|
||||
struct {
|
||||
std::string CFS = "CFS";
|
||||
std::string FCBF = "FCBF";
|
||||
std::string IWSS = "IWSS";
|
||||
}SelectFeatures;
|
||||
struct {
|
||||
std::string ASC = "asc";
|
||||
std::string DESC = "desc";
|
||||
std::string RAND = "rand";
|
||||
}Orders;
|
||||
BoostAODE::BoostAODE(bool predict_voting) : Ensemble(predict_voting)
|
||||
{
|
||||
validHyperparameters = {
|
||||
@ -61,10 +71,10 @@ namespace bayesnet {
|
||||
hyperparameters.erase("maxModels");
|
||||
}
|
||||
if (hyperparameters.contains("order")) {
|
||||
std::vector<std::string> algos = { "asc", "desc", "rand" };
|
||||
std::vector<std::string> algos = { Orders.ASC, Orders.DESC, Orders.RAND };
|
||||
order_algorithm = hyperparameters["order"];
|
||||
if (std::find(algos.begin(), algos.end(), order_algorithm) == algos.end()) {
|
||||
throw std::invalid_argument("Invalid order algorithm, valid values [asc, desc, rand]");
|
||||
throw std::invalid_argument("Invalid order algorithm, valid values [" + Orders.ASC + ", " + Orders.DESC + ", " + Orders.RAND + "]");
|
||||
}
|
||||
hyperparameters.erase("order");
|
||||
}
|
||||
@ -90,11 +100,11 @@ namespace bayesnet {
|
||||
}
|
||||
if (hyperparameters.contains("select_features")) {
|
||||
auto selectedAlgorithm = hyperparameters["select_features"];
|
||||
std::vector<std::string> algos = { "IWSS", "FCBF", "CFS" };
|
||||
std::vector<std::string> algos = { SelectFeatures.IWSS, SelectFeatures.CFS, SelectFeatures.CFS };
|
||||
selectFeatures = true;
|
||||
select_features_algorithm = selectedAlgorithm;
|
||||
if (std::find(algos.begin(), algos.end(), selectedAlgorithm) == algos.end()) {
|
||||
throw std::invalid_argument("Invalid selectFeatures value, valid values [IWSS, FCBF, CFS]");
|
||||
throw std::invalid_argument("Invalid selectFeatures value, valid values [" + SelectFeatures.IWSS + ", " + SelectFeatures.CFS + ", " + SelectFeatures.FCBF + "]");
|
||||
}
|
||||
hyperparameters.erase("select_features");
|
||||
}
|
||||
@ -107,16 +117,16 @@ namespace bayesnet {
|
||||
std::unordered_set<int> featuresUsed;
|
||||
torch::Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
|
||||
int maxFeatures = 0;
|
||||
if (select_features_algorithm == "CFS") {
|
||||
if (select_features_algorithm == SelectFeatures.CFS) {
|
||||
featureSelector = new CFS(dataset, features, className, maxFeatures, states.at(className).size(), weights_);
|
||||
} else if (select_features_algorithm == "IWSS") {
|
||||
} else if (select_features_algorithm == SelectFeatures.IWSS) {
|
||||
if (threshold < 0 || threshold >0.5) {
|
||||
throw std::invalid_argument("Invalid threshold value for IWSS [0, 0.5]");
|
||||
throw std::invalid_argument("Invalid threshold value for " + SelectFeatures.IWSS + " [0, 0.5]");
|
||||
}
|
||||
featureSelector = new IWSS(dataset, features, className, maxFeatures, states.at(className).size(), weights_, threshold);
|
||||
} else if (select_features_algorithm == "FCBF") {
|
||||
} else if (select_features_algorithm == SelectFeatures.FCBF) {
|
||||
if (threshold < 1e-7 || threshold > 1) {
|
||||
throw std::invalid_argument("Invalid threshold value [1e-7, 1]");
|
||||
throw std::invalid_argument("Invalid threshold value for " + SelectFeatures.FCBF + " [1e-7, 1]");
|
||||
}
|
||||
featureSelector = new FCBF(dataset, features, className, maxFeatures, states.at(className).size(), weights_, threshold);
|
||||
}
|
||||
@ -174,12 +184,12 @@ namespace bayesnet {
|
||||
// n_models == maxModels
|
||||
// epsilon sub t > 0.5 => inverse the weights policy
|
||||
// validation error is not decreasing
|
||||
bool ascending = order_algorithm == "asc";
|
||||
bool ascending = order_algorithm == Orders.ASC;
|
||||
std::mt19937 g{ 173 };
|
||||
while (!exitCondition) {
|
||||
// Step 1: Build ranking with mutual information
|
||||
auto featureSelection = metrics.SelectKBestWeighted(weights_, ascending, n); // Get all the features sorted
|
||||
if (order_algorithm == "rand") {
|
||||
if (order_algorithm == Orders.RAND) {
|
||||
std::shuffle(featureSelection.begin(), featureSelection.end(), g);
|
||||
}
|
||||
auto feature = featureSelection[0];
|
||||
|
Loading…
Reference in New Issue
Block a user