Refactor GridSearch to expose a single go() method (replacing goSingle() and goNested())

This commit is contained in:
2023-12-02 10:59:05 +01:00
parent 33cd32c639
commit 03e4437fea
5 changed files with 176 additions and 137 deletions

View File

@@ -33,7 +33,7 @@ void manageArguments(argparse::ArgumentParser& program)
program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
program.add_argument("--continue").help("Continue computing from that dataset").default_value("No");
program.add_argument("--continue").help("Continue computing from that dataset").default_value(platform::GridSearch::NO_CONTINUE());
program.add_argument("--only").help("Used with continue to compute that dataset only").default_value(false).implicit_value(true);
program.add_argument("--nested").help("Do a double/nested cross validation with n folds").default_value(0).scan<'i', int>();
program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy");
@@ -95,7 +95,7 @@ void list_results(json& results, std::string& model)
{
std::cout << Colors::MAGENTA() << std::string(MAXL, '*') << std::endl;
std::cout << headerLine("Listing computed hyperparameters for model " + model);
std::cout << headerLine("Date & time: " + results["date"].get<std::string>());
std::cout << headerLine("Date & time: " + results["date"].get<std::string>() + " Duration: " + results["duration"].get<std::string>());
std::cout << headerLine("Score: " + results["score"].get<std::string>());
std::cout << headerLine(
"Random seeds: " + results["seeds"].dump()
@@ -118,9 +118,9 @@ void list_results(json& results, std::string& model)
}
}
std::cout << Colors::GREEN() << " # " << left << setw(spaces) << "Dataset" << " " << setw(19) << "Date" << " "
<< setw(8) << "Score" << " " << "Hyperparameters" << std::endl;
<< "Duration " << setw(8) << "Score" << " " << "Hyperparameters" << std::endl;
std::cout << "=== " << string(spaces, '=') << " " << string(19, '=') << " " << string(8, '=') << " "
<< string(hyperparameters_spaces, '=') << std::endl;
<< string(8, '=') << " " << string(hyperparameters_spaces, '=') << std::endl;
bool odd = true;
int index = 0;
for (const auto& item : results["results"].items()) {
@@ -130,8 +130,8 @@ void list_results(json& results, std::string& model)
std::cout << color;
std::cout << std::setw(3) << std::right << index++ << " ";
std::cout << left << setw(spaces) << key << " " << value["date"].get<string>()
<< " " << setw(8) << setprecision(6) << fixed << right
<< value["score"].get<double>() << " " << value["hyperparameters"].dump() << std::endl;
<< " " << setw(8) << value["duration"] << " " << setw(8) << setprecision(6)
<< fixed << right << value["score"].get<double>() << " " << value["hyperparameters"].dump() << std::endl;
odd = !odd;
}
std::cout << Colors::RESET() << std::endl;
@@ -159,12 +159,12 @@ int main(int argc, char** argv)
config.seeds = program.get<std::vector<int>>("seeds");
config.nested = program.get<int>("nested");
config.continue_from = program.get<std::string>("continue");
if (config.continue_from == "No" && config.only) {
if (config.continue_from == platform::GridSearch::NO_CONTINUE() && config.only) {
throw std::runtime_error("Cannot use --only without --continue");
}
dump = program.get<bool>("dump");
compute = program.get<bool>("compute");
if (dump && (config.continue_from != "No" || config.only)) {
if (dump && (config.continue_from != platform::GridSearch::NO_CONTINUE() || config.only)) {
throw std::runtime_error("Cannot use --dump with --continue or --only");
}
}
@@ -177,6 +177,7 @@ int main(int argc, char** argv)
* Begin Processing
*/
auto env = platform::DotEnv();
config.platform = env.get("platform");
platform::Paths::createPath(platform::Paths::grid());
auto grid_search = platform::GridSearch(config);
platform::Timer timer;
@@ -185,10 +186,7 @@ int main(int argc, char** argv)
list_dump(config.model);
} else {
if (compute) {
if (config.nested == 0)
grid_search.goSingle();
else
grid_search.goNested();
grid_search.go();
std::cout << "Process took " << timer.getDurationString() << std::endl;
} else {
// List results