Fix tolerance hyperp error & gridsearch

This commit is contained in:
2023-11-29 12:33:50 +01:00
parent 460d20a402
commit e3f6dc1e0b
3 changed files with 19 additions and 4 deletions

View File

@@ -89,6 +89,8 @@ namespace platform {
double bestScore = 0.0;
for (int nfold = 0; nfold < config.n_folds; nfold++) {
auto clf = Models::instance()->create(config.model);
auto valid = clf->getValidHyperparameters();
hyperparameters.check(valid, fileName);
clf->setHyperparameters(hyperparameters.get(fileName));
auto [train, test] = fold->getFold(nfold);
auto train_t = torch::tensor(train);