Fix tolerance hyperp error & gridsearch
This commit is contained in:
parent
460d20a402
commit
e3f6dc1e0b
@ -12,7 +12,7 @@
|
|||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
BoostAODE::BoostAODE() : Ensemble()
|
BoostAODE::BoostAODE() : Ensemble()
|
||||||
{
|
{
|
||||||
validHyperparameters = { "repeatSparent", "maxModels", "ascending", "convergence", "threshold", "select_features" };
|
validHyperparameters = { "repeatSparent", "maxModels", "ascending", "convergence", "threshold", "select_features", "tolerance" };
|
||||||
|
|
||||||
}
|
}
|
||||||
void BoostAODE::buildModel(const torch::Tensor& weights)
|
void BoostAODE::buildModel(const torch::Tensor& weights)
|
||||||
@ -47,22 +47,32 @@ namespace bayesnet {
|
|||||||
y_train = y_;
|
y_train = y_;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
void BoostAODE::setHyperparameters(const nlohmann::json& hyperparameters)
|
void BoostAODE::setHyperparameters(const nlohmann::json& hyperparameters_)
|
||||||
{
|
{
|
||||||
|
auto hyperparameters = hyperparameters_;
|
||||||
if (hyperparameters.contains("repeatSparent")) {
|
if (hyperparameters.contains("repeatSparent")) {
|
||||||
repeatSparent = hyperparameters["repeatSparent"];
|
repeatSparent = hyperparameters["repeatSparent"];
|
||||||
|
hyperparameters.erase("repeatSparent");
|
||||||
}
|
}
|
||||||
if (hyperparameters.contains("maxModels")) {
|
if (hyperparameters.contains("maxModels")) {
|
||||||
maxModels = hyperparameters["maxModels"];
|
maxModels = hyperparameters["maxModels"];
|
||||||
|
hyperparameters.erase("maxModels");
|
||||||
}
|
}
|
||||||
if (hyperparameters.contains("ascending")) {
|
if (hyperparameters.contains("ascending")) {
|
||||||
ascending = hyperparameters["ascending"];
|
ascending = hyperparameters["ascending"];
|
||||||
|
hyperparameters.erase("ascending");
|
||||||
}
|
}
|
||||||
if (hyperparameters.contains("convergence")) {
|
if (hyperparameters.contains("convergence")) {
|
||||||
convergence = hyperparameters["convergence"];
|
convergence = hyperparameters["convergence"];
|
||||||
|
hyperparameters.erase("convergence");
|
||||||
}
|
}
|
||||||
if (hyperparameters.contains("threshold")) {
|
if (hyperparameters.contains("threshold")) {
|
||||||
threshold = hyperparameters["threshold"];
|
threshold = hyperparameters["threshold"];
|
||||||
|
hyperparameters.erase("threshold");
|
||||||
|
}
|
||||||
|
if (hyperparameters.contains("tolerance")) {
|
||||||
|
tolerance = hyperparameters["tolerance"];
|
||||||
|
hyperparameters.erase("tolerance");
|
||||||
}
|
}
|
||||||
if (hyperparameters.contains("select_features")) {
|
if (hyperparameters.contains("select_features")) {
|
||||||
auto selectedAlgorithm = hyperparameters["select_features"];
|
auto selectedAlgorithm = hyperparameters["select_features"];
|
||||||
@ -72,6 +82,10 @@ namespace bayesnet {
|
|||||||
if (std::find(algos.begin(), algos.end(), selectedAlgorithm) == algos.end()) {
|
if (std::find(algos.begin(), algos.end(), selectedAlgorithm) == algos.end()) {
|
||||||
throw std::invalid_argument("Invalid selectFeatures value [IWSS, FCBF, CFS]");
|
throw std::invalid_argument("Invalid selectFeatures value [IWSS, FCBF, CFS]");
|
||||||
}
|
}
|
||||||
|
hyperparameters.erase("select_features");
|
||||||
|
}
|
||||||
|
if (!hyperparameters.empty()) {
|
||||||
|
throw std::invalid_argument("Invalid hyperparameters" + hyperparameters.dump());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
std::unordered_set<int> BoostAODE::initializeModels()
|
std::unordered_set<int> BoostAODE::initializeModels()
|
||||||
@ -109,10 +123,8 @@ namespace bayesnet {
|
|||||||
void BoostAODE::trainModel(const torch::Tensor& weights)
|
void BoostAODE::trainModel(const torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
std::unordered_set<int> featuresUsed;
|
std::unordered_set<int> featuresUsed;
|
||||||
int tolerance = 5; // number of times the accuracy can be lower than the threshold
|
|
||||||
if (selectFeatures) {
|
if (selectFeatures) {
|
||||||
featuresUsed = initializeModels();
|
featuresUsed = initializeModels();
|
||||||
tolerance = 0; // Remove tolerance if features are selected
|
|
||||||
}
|
}
|
||||||
if (maxModels == 0)
|
if (maxModels == 0)
|
||||||
maxModels = .1 * n > 10 ? .1 * n : n;
|
maxModels = .1 * n > 10 ? .1 * n : n;
|
||||||
|
@ -21,6 +21,7 @@ namespace bayesnet {
|
|||||||
// Hyperparameters
|
// Hyperparameters
|
||||||
bool repeatSparent = false; // if true, a feature can be selected more than once
|
bool repeatSparent = false; // if true, a feature can be selected more than once
|
||||||
int maxModels = 0;
|
int maxModels = 0;
|
||||||
|
int tolerance = 0;
|
||||||
bool ascending = false; //Process KBest features ascending or descending order
|
bool ascending = false; //Process KBest features ascending or descending order
|
||||||
bool convergence = false; //if true, stop when the model does not improve
|
bool convergence = false; //if true, stop when the model does not improve
|
||||||
bool selectFeatures = false; // if true, use feature selection
|
bool selectFeatures = false; // if true, use feature selection
|
||||||
|
@ -89,6 +89,8 @@ namespace platform {
|
|||||||
double bestScore = 0.0;
|
double bestScore = 0.0;
|
||||||
for (int nfold = 0; nfold < config.n_folds; nfold++) {
|
for (int nfold = 0; nfold < config.n_folds; nfold++) {
|
||||||
auto clf = Models::instance()->create(config.model);
|
auto clf = Models::instance()->create(config.model);
|
||||||
|
auto valid = clf->getValidHyperparameters();
|
||||||
|
hyperparameters.check(valid, fileName);
|
||||||
clf->setHyperparameters(hyperparameters.get(fileName));
|
clf->setHyperparameters(hyperparameters.get(fileName));
|
||||||
auto [train, test] = fold->getFold(nfold);
|
auto [train, test] = fold->getFold(nfold);
|
||||||
auto train_t = torch::tensor(train);
|
auto train_t = torch::tensor(train);
|
||||||
|
Loading…
Reference in New Issue
Block a user