diff --git a/src/Platform/GridSearch.cc b/src/Platform/GridSearch.cc index e6b25af..20fed05 100644 --- a/src/Platform/GridSearch.cc +++ b/src/Platform/GridSearch.cc @@ -63,7 +63,7 @@ namespace platform { return Colors::RESET(); } } - double GridSearch::processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters) + double GridSearch::processFileSingle(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters) { // Get dataset auto [X, y] = datasets.getTensors(fileName); @@ -135,11 +135,8 @@ namespace platform { } return datasets_names; } - - void GridSearch::go() + json GridSearch::initializeResults() { - auto datasets = Datasets(config.discretize, Paths::datasets()); - auto datasets_names = processDatasets(datasets); // Load previous results json results; if (config.continue_from != "No") { @@ -149,6 +146,7 @@ namespace platform { std::ifstream file(Paths::grid_output(config.model)); if (file.is_open()) { results = json::parse(file); + results = results["results"]; } } catch (const std::exception& e) { @@ -157,7 +155,15 @@ namespace platform { results = json(); } } - std::cout << "***************** Starting Gridsearch *****************" << std::endl; + return results; + } + + void GridSearch::goSingle() + { + auto datasets = Datasets(config.discretize, Paths::datasets()); + auto datasets_names = processDatasets(datasets); + json results = initializeResults(); + std::cout << "***************** Starting Single Gridsearch *****************" << std::endl; std::cout << "input file=" << Paths::grid_input(config.model) << std::endl; auto grid = GridData(Paths::grid_input(config.model)); // Generate hyperparameters grid & run gridsearch @@ -174,7 +180,7 @@ namespace platform { if (!config.quiet) showProgressComb(++num, totalComb, Colors::CYAN()); auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line); - double score = processFile(dataset, datasets, hyperparameters); + double score = processFileSingle(dataset, datasets, hyperparameters); if (score > bestScore) { bestScore = score; bestHyperparameters = hyperparam_line; @@ -184,20 +190,80 @@ namespace platform { std::cout << "end." << " Score: " << setw(9) << setprecision(7) << fixed << bestScore << " [" << bestHyperparameters.dump() << "]" << std::endl; } - results[dataset]["score"] = bestScore; - results[dataset]["hyperparameters"] = bestHyperparameters; - results[dataset]["date"] = get_date() + " " + get_time(); - results[dataset]["grid"] = grid.getInputGrid(dataset); + json result = { + { "score", bestScore }, + { "hyperparameters", bestHyperparameters }, + { "date", get_date() + " " + get_time() }, + { "grid", grid.getInputGrid(dataset) } + }; + results[dataset] = result; // Save partial results save(results); } // Save final results save(results); - std::cout << "***************** Ending Gridsearch *******************" << std::endl; + std::cout << "***************** Ending Single Gridsearch *******************" << std::endl; + } + void GridSearch::goNested() + { + auto datasets = Datasets(config.discretize, Paths::datasets()); + auto datasets_names = processDatasets(datasets); + json results = initializeResults(); + std::cout << "***************** Starting Nested Gridsearch *****************" << std::endl; + std::cout << "input file=" << Paths::grid_input(config.model) << std::endl; + auto grid = GridData(Paths::grid_input(config.model)); + // Generate hyperparameters grid & run gridsearch + // Check each combination of hyperparameters for each dataset and each seed + for (const auto& dataset : datasets_names) { + auto totalComb = grid.getNumCombinations(dataset); + if (!config.quiet) + std::cout << "- " << setw(20) << left << dataset << " " << right << flush; + int num = 0; + double bestScore = 0.0; + json bestHyperparameters; + auto combinations = grid.getGrid(dataset); + for (const auto& hyperparam_line : combinations) { + if (!config.quiet) + showProgressComb(++num, totalComb, Colors::CYAN()); + auto hyperparameters = platform::HyperParameters(datasets.getNames(), hyperparam_line); + double score = processFileSingle(dataset, datasets, hyperparameters); + if (score > bestScore) { + bestScore = score; + bestHyperparameters = hyperparam_line; + } + } + if (!config.quiet) { + std::cout << "end." << " Score: " << setw(9) << setprecision(7) << fixed + << bestScore << " [" << bestHyperparameters.dump() << "]" << std::endl; + } + json result = { + { "score", bestScore }, + { "hyperparameters", bestHyperparameters }, + { "date", get_date() + " " + get_time() }, + { "grid", grid.getInputGrid(dataset) } + }; + results[dataset] = result; + // Save partial results + save(results); + } + // Save final results + save(results); + std::cout << "***************** Ending Nested Gridsearch *******************" << std::endl; } void GridSearch::save(json& results) const { std::ofstream file(Paths::grid_output(config.model)); - file << results.dump(4); + json output = { + { "model", config.model }, + { "score", config.score }, + { "discretize", config.discretize }, + { "stratified", config.stratified }, + { "n_folds", config.n_folds }, + { "seeds", config.seeds }, + { "date", get_date() + " " + get_time()}, + { "nested", config.nested}, + { "results", results } + }; + file << output.dump(4); } } /* namespace platform */ \ No newline at end of file diff --git a/src/Platform/GridSearch.h b/src/Platform/GridSearch.h index e8bfb85..d27c59d 100644 --- a/src/Platform/GridSearch.h +++ b/src/Platform/GridSearch.h @@ -17,19 +17,22 @@ namespace platform { bool only; // used with continue_from to only compute that dataset bool discretize; bool stratified; + int nested; int n_folds; std::vector seeds; }; class GridSearch { public: explicit GridSearch(struct ConfigGrid& config); - void go(); + void goSingle(); + void goNested(); ~GridSearch() = default; json getResults(); private: void save(json& results) const; + json initializeResults(); vector processDatasets(Datasets& datasets); - double processFile(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters); + double processFileSingle(std::string fileName, Datasets& datasets, HyperParameters& hyperparameters); struct ConfigGrid config; }; } /* namespace platform */ diff --git a/src/Platform/b_grid.cc b/src/Platform/b_grid.cc index 64452ed..f190c5d 100644 --- a/src/Platform/b_grid.cc +++ b/src/Platform/b_grid.cc @@ -11,6 +11,7 @@ #include "Colors.h" using json = nlohmann::json; +const int MAXL = 133; void manageArguments(argparse::ArgumentParser& program) { @@ -27,13 +28,14 @@ void manageArguments(argparse::ArgumentParser& program) } ); group.add_argument("--dump").help("Show the grid combinations").default_value(false).implicit_value(true); - group.add_argument("--list").help("List the computed hyperparameters").default_value(false).implicit_value(true); + group.add_argument("--report").help("Report the computed hyperparameters").default_value(false).implicit_value(true); group.add_argument("--compute").help("Perform computation of the grid output hyperparameters").default_value(false).implicit_value(true); program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true); program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true); program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true); program.add_argument("--continue").help("Continue computing from that dataset").default_value("No"); program.add_argument("--only").help("Used with continue to compute that dataset only").default_value(false).implicit_value(true); + program.add_argument("--nested").help("Do a double/nested cross validation with n folds").default_value(0).scan<'i', int>(); program.add_argument("--score").help("Score used in gridsearch").default_value("accuracy"); program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) { try { @@ -83,13 +85,29 @@ void list_dump(std::string& model) } std::cout << Colors::RESET() << std::endl; } +std::string headerLine(const std::string& text, int utf = 0) +{ + int n = MAXL - text.length() - 3; + n = n < 0 ? 0 : n; + return "* " + text + std::string(n + utf, ' ') + "*\n"; +} void list_results(json& results, std::string& model) { - std::cout << Colors::MAGENTA() << "Listing computed hyperparameters for model " - << model << std::endl << std::endl; + std::cout << Colors::MAGENTA() << std::string(MAXL, '*') << std::endl; + std::cout << headerLine("Listing computed hyperparameters for model " + model); + std::cout << headerLine("Date & time: " + results["date"].get()); + std::cout << headerLine("Score: " + results["score"].get()); + std::cout << headerLine( + "Random seeds: " + results["seeds"].dump() + + " Discretized: " + (results["discretize"].get() ? "True" : "False") + + " Stratified: " + (results["stratified"].get() ? "True" : "False") + + " #Folds: " + std::to_string(results["n_folds"].get()) + + " Nested: " + (results["nested"].get() == 0 ? "False" : to_string(results["nested"].get())) + ); + std::cout << std::string(MAXL, '*') << std::endl; int spaces = 0; int hyperparameters_spaces = 0; - for (const auto& item : results.items()) { + for (const auto& item : results["results"].items()) { auto key = item.key(); auto value = item.value(); if (key.size() > spaces) { @@ -105,7 +123,7 @@ void list_results(json& results, std::string& model) << string(hyperparameters_spaces, '=') << std::endl; bool odd = true; int index = 0; - for (const auto& item : results.items()) { + for (const auto& item : results["results"].items()) { auto color = odd ? Colors::CYAN() : Colors::BLUE(); auto key = item.key(); auto value = item.value(); @@ -119,12 +137,16 @@ void list_results(json& results, std::string& model) std::cout << Colors::RESET() << std::endl; } + +/* + * Main + */ int main(int argc, char** argv) { argparse::ArgumentParser program("b_grid"); manageArguments(program); struct platform::ConfigGrid config; - bool dump, compute, list; + bool dump, compute; try { program.parse_args(argc, argv); config.model = program.get("model"); @@ -135,13 +157,13 @@ int main(int argc, char** argv) config.quiet = program.get("quiet"); config.only = program.get("only"); config.seeds = program.get>("seeds"); + config.nested = program.get("nested"); config.continue_from = program.get("continue"); if (config.continue_from == "No" && config.only) { throw std::runtime_error("Cannot use --only without --continue"); } dump = program.get("dump"); compute = program.get("compute"); - list = program.get("list"); if (dump && (config.continue_from != "No" || config.only)) { throw std::runtime_error("Cannot use --dump with --continue or --only"); } @@ -163,7 +185,10 @@ int main(int argc, char** argv) list_dump(config.model); } else { if (compute) { - grid_search.go(); + if (config.nested == 0) + grid_search.goSingle(); + else + grid_search.goNested(); std::cout << "Process took " << timer.getDurationString() << std::endl; } else { // List results