Complete nested gridsearch
This commit is contained in:
parent
fb9b395748
commit
2e95e8999d
@ -172,8 +172,9 @@ namespace platform {
|
|||||||
auto states = datasets.getStates(fileName);
|
auto states = datasets.getStates(fileName);
|
||||||
auto features = datasets.getFeatures(fileName);
|
auto features = datasets.getFeatures(fileName);
|
||||||
auto className = datasets.getClassName(fileName);
|
auto className = datasets.getClassName(fileName);
|
||||||
double bestScore = 0.0;
|
int spcs_combinations = int(log(combinations.size()) / log(10)) + 1;
|
||||||
json bestHyperparameters;
|
double goatScore = 0.0;
|
||||||
|
json goatHyperparameters;
|
||||||
// for dataset // for seed // for fold // for hyperparameters // for nested fold
|
// for dataset // for seed // for fold // for hyperparameters // for nested fold
|
||||||
for (const auto& seed : config.seeds) {
|
for (const auto& seed : config.seeds) {
|
||||||
Fold* fold;
|
Fold* fold;
|
||||||
@ -182,9 +183,11 @@ namespace platform {
|
|||||||
else
|
else
|
||||||
fold = new KFold(config.n_folds, y.size(0), seed);
|
fold = new KFold(config.n_folds, y.size(0), seed);
|
||||||
double bestScore = 0.0;
|
double bestScore = 0.0;
|
||||||
|
json bestHyperparameters;
|
||||||
std::cout << "(" << seed << ") doing Fold: " << flush;
|
std::cout << "(" << seed << ") doing Fold: " << flush;
|
||||||
for (int nfold = 0; nfold < config.n_folds; nfold++) {
|
for (int nfold = 0; nfold < config.n_folds; nfold++) {
|
||||||
std::cout << nfold + 1 << flush;
|
if (!config.quiet)
|
||||||
|
std::cout << Colors::GREEN() << nfold + 1 << " " << flush;
|
||||||
// First level fold
|
// First level fold
|
||||||
auto [train, test] = fold->getFold(nfold);
|
auto [train, test] = fold->getFold(nfold);
|
||||||
auto train_t = torch::tensor(train);
|
auto train_t = torch::tensor(train);
|
||||||
@ -195,9 +198,12 @@ namespace platform {
|
|||||||
auto y_test = y.index({ test_t });
|
auto y_test = y.index({ test_t });
|
||||||
auto num = 0;
|
auto num = 0;
|
||||||
json result_fold;
|
json result_fold;
|
||||||
double score = 0.0;
|
double hypScore = 0.0;
|
||||||
|
double bestHypScore = 0.0;
|
||||||
|
json bestHypHyperparameters;
|
||||||
for (const auto& hyperparam_line : combinations) {
|
for (const auto& hyperparam_line : combinations) {
|
||||||
std::cout << "[" << ++num << "/" << combinations.size() << "] " << std::flush;
|
std::cout << "[" << setw(spcs_combinations) << ++num << "/" << setw(spcs_combinations)
|
||||||
|
<< combinations.size() << "] " << std::flush;
|
||||||
Fold* nested_fold;
|
Fold* nested_fold;
|
||||||
if (config.stratified)
|
if (config.stratified)
|
||||||
nested_fold = new StratifiedKFold(config.nested, y_train, seed);
|
nested_fold = new StratifiedKFold(config.nested, y_train, seed);
|
||||||
@ -221,21 +227,47 @@ namespace platform {
|
|||||||
// Train model
|
// Train model
|
||||||
if (!config.quiet)
|
if (!config.quiet)
|
||||||
showProgressFold(n_nested_fold + 1, getColor(clf->getStatus()), "a");
|
showProgressFold(n_nested_fold + 1, getColor(clf->getStatus()), "a");
|
||||||
// clf->fit(X_nexted_train, y_nested_train, features, className, states);
|
clf->fit(X_nexted_train, y_nested_train, features, className, states);
|
||||||
// Test model
|
// Test model
|
||||||
if (!config.quiet)
|
if (!config.quiet)
|
||||||
showProgressFold(n_nested_fold + 1, getColor(clf->getStatus()), "b");
|
showProgressFold(n_nested_fold + 1, getColor(clf->getStatus()), "b");
|
||||||
// score += clf->score(X_nested_test, y_nested_test);
|
hypScore += clf->score(X_nested_test, y_nested_test);
|
||||||
score = 0.0;
|
if (!config.quiet)
|
||||||
|
std::cout << "\b\b\b, \b" << flush;
|
||||||
}
|
}
|
||||||
|
std::cout << string(3 * config.nested + 2 * spcs_combinations + 4, '\b') << flush;
|
||||||
delete nested_fold;
|
delete nested_fold;
|
||||||
score = score / config.nested;
|
hypScore /= config.nested;
|
||||||
|
if (hypScore > bestHypScore) {
|
||||||
|
bestHypScore = hypScore;
|
||||||
|
bestHypHyperparameters = hyperparam_line;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
// Build Classifier with selected hyperparameters
|
||||||
|
auto clf = Models::instance()->create(config.model);
|
||||||
|
clf->setHyperparameters(bestHypHyperparameters);
|
||||||
|
// Train model
|
||||||
|
if (!config.quiet)
|
||||||
|
showProgressFold(nfold + 1, getColor(clf->getStatus()), "a");
|
||||||
|
clf->fit(X_train, y_train, features, className, states);
|
||||||
|
// Test model
|
||||||
|
if (!config.quiet)
|
||||||
|
showProgressFold(nfold + 1, getColor(clf->getStatus()), "b");
|
||||||
|
double score = clf->score(X_test, y_test);
|
||||||
|
if (!config.quiet)
|
||||||
|
std::cout << "\b\b\b\b\b, \b" << flush;
|
||||||
|
if (score > bestScore) {
|
||||||
|
bestScore = score;
|
||||||
|
bestHyperparameters = bestHypHyperparameters;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (bestScore > goatScore) {
|
||||||
|
goatScore = bestScore;
|
||||||
|
goatHyperparameters = bestHyperparameters;
|
||||||
}
|
}
|
||||||
delete fold;
|
delete fold;
|
||||||
}
|
}
|
||||||
return { bestScore, bestHyperparameters };
|
return { goatScore, goatHyperparameters };
|
||||||
}
|
}
|
||||||
vector<std::string> GridSearch::processDatasets(Datasets& datasets)
|
vector<std::string> GridSearch::processDatasets(Datasets& datasets)
|
||||||
{
|
{
|
||||||
|
Loading…
Reference in New Issue
Block a user