refactor gridsearch to have only one go method
This commit is contained in:
@@ -33,7 +33,7 @@ void manageArguments(argparse::ArgumentParser& program)
|
||||
program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
|
||||
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
|
||||
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
|
||||
program.add_argument("--continue").help("Continue computing from that dataset").default_value("No");
|
||||
program.add_argument("--continue").help("Continue computing from that dataset").default_value(platform::GridSearch::NO_CONTINUE());
|
||||
program.add_argument("--only").help("Used with continue to compute that dataset only").default_value(false).implicit_value(true);
|
||||
program.add_argument("--nested").help("Do a double/nested cross validation with n folds").default_value(0).scan<'i', int>();
|
||||
program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy");
|
||||
@@ -95,7 +95,7 @@ void list_results(json& results, std::string& model)
|
||||
{
|
||||
std::cout << Colors::MAGENTA() << std::string(MAXL, '*') << std::endl;
|
||||
std::cout << headerLine("Listing computed hyperparameters for model " + model);
|
||||
std::cout << headerLine("Date & time: " + results["date"].get<std::string>());
|
||||
std::cout << headerLine("Date & time: " + results["date"].get<std::string>() + " Duration: " + results["duration"].get<std::string>());
|
||||
std::cout << headerLine("Score: " + results["score"].get<std::string>());
|
||||
std::cout << headerLine(
|
||||
"Random seeds: " + results["seeds"].dump()
|
||||
@@ -118,9 +118,9 @@ void list_results(json& results, std::string& model)
|
||||
}
|
||||
}
|
||||
std::cout << Colors::GREEN() << " # " << left << setw(spaces) << "Dataset" << " " << setw(19) << "Date" << " "
|
||||
<< setw(8) << "Score" << " " << "Hyperparameters" << std::endl;
|
||||
<< "Duration " << setw(8) << "Score" << " " << "Hyperparameters" << std::endl;
|
||||
std::cout << "=== " << string(spaces, '=') << " " << string(19, '=') << " " << string(8, '=') << " "
|
||||
<< string(hyperparameters_spaces, '=') << std::endl;
|
||||
<< string(8, '=') << " " << string(hyperparameters_spaces, '=') << std::endl;
|
||||
bool odd = true;
|
||||
int index = 0;
|
||||
for (const auto& item : results["results"].items()) {
|
||||
@@ -130,8 +130,8 @@ void list_results(json& results, std::string& model)
|
||||
std::cout << color;
|
||||
std::cout << std::setw(3) << std::right << index++ << " ";
|
||||
std::cout << left << setw(spaces) << key << " " << value["date"].get<string>()
|
||||
<< " " << setw(8) << setprecision(6) << fixed << right
|
||||
<< value["score"].get<double>() << " " << value["hyperparameters"].dump() << std::endl;
|
||||
<< " " << setw(8) << value["duration"] << " " << setw(8) << setprecision(6)
|
||||
<< fixed << right << value["score"].get<double>() << " " << value["hyperparameters"].dump() << std::endl;
|
||||
odd = !odd;
|
||||
}
|
||||
std::cout << Colors::RESET() << std::endl;
|
||||
@@ -159,12 +159,12 @@ int main(int argc, char** argv)
|
||||
config.seeds = program.get<std::vector<int>>("seeds");
|
||||
config.nested = program.get<int>("nested");
|
||||
config.continue_from = program.get<std::string>("continue");
|
||||
if (config.continue_from == "No" && config.only) {
|
||||
if (config.continue_from == platform::GridSearch::NO_CONTINUE() && config.only) {
|
||||
throw std::runtime_error("Cannot use --only without --continue");
|
||||
}
|
||||
dump = program.get<bool>("dump");
|
||||
compute = program.get<bool>("compute");
|
||||
if (dump && (config.continue_from != "No" || config.only)) {
|
||||
if (dump && (config.continue_from != platform::GridSearch::NO_CONTINUE() || config.only)) {
|
||||
throw std::runtime_error("Cannot use --dump with --continue or --only");
|
||||
}
|
||||
}
|
||||
@@ -177,6 +177,7 @@ int main(int argc, char** argv)
|
||||
* Begin Processing
|
||||
*/
|
||||
auto env = platform::DotEnv();
|
||||
config.platform = env.get("platform");
|
||||
platform::Paths::createPath(platform::Paths::grid());
|
||||
auto grid_search = platform::GridSearch(config);
|
||||
platform::Timer timer;
|
||||
@@ -185,10 +186,7 @@ int main(int argc, char** argv)
|
||||
list_dump(config.model);
|
||||
} else {
|
||||
if (compute) {
|
||||
if (config.nested == 0)
|
||||
grid_search.goSingle();
|
||||
else
|
||||
grid_search.goNested();
|
||||
grid_search.go();
|
||||
std::cout << "Process took " << timer.getDurationString() << std::endl;
|
||||
} else {
|
||||
// List results
|
||||
|
Reference in New Issue
Block a user