From 2b20d0315cb7a806d56e50687c17f13354325691 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Mon, 15 Jan 2024 11:53:34 +0100 Subject: [PATCH] Add b_main support to grid_output files --- src/Platform/b_grid.cc | 25 +------------------------ src/Platform/modules/GridSearch.cc | 11 ----------- src/Platform/modules/GridSearch.h | 1 - src/Platform/modules/HyperParameters.cc | 3 ++- src/Platform/modules/Paths.h | 4 ---- 5 files changed, 3 insertions(+), 41 deletions(-) diff --git a/src/Platform/b_grid.cc b/src/Platform/b_grid.cc index 8c841b1..c344bf6 100644 --- a/src/Platform/b_grid.cc +++ b/src/Platform/b_grid.cc @@ -177,22 +177,6 @@ void report(argparse::ArgumentParser& program) list_results(results, config.model); } } -void exportResults(argparse::ArgumentParser& program) -{ - // Generate a grid_.json file with the results of the grid search - // this file can be used by b_main to run the model with the best hyperparameters - struct platform::ConfigGrid config; - config.model = program.get("model"); - auto grid_search = platform::GridSearch(config); - auto results = grid_search.loadResults(); - auto output = json::array(); - if (results.empty()) { - std::cout << "** No results found" << std::endl; - } else { - grid_search.exportResults(results); - std::cout << "Exported results to " << platform::Paths::grid_export(config.model) << std::endl; - } -} void compute(argparse::ArgumentParser& program) { struct platform::ConfigGrid config; @@ -256,15 +240,9 @@ int main(int argc, char** argv) assignModel(compute_command); add_compute_args(compute_command); - // grid export subparser - argparse::ArgumentParser export_command("export"); - assignModel(export_command); - export_command.add_description("Export the computed hyperparameters to a file readable by b_main."); - program.add_subparser(dump_command); program.add_subparser(report_command); program.add_subparser(compute_command); - program.add_subparser(export_command); // // Process options @@ -272,8 +250,7 @@ int main(int argc, char** argv) try { program.parse_args(argc, argv); bool found = false; - map commands = - { {"dump", &dump}, {"report", &report}, {"export", &exportResults}, {"compute", &compute} }; + map commands = { {"dump", &dump}, {"report", &report}, {"compute", &compute} }; for (const auto& command : commands) { if (program.is_subcommand_used(command.first)) { std::invoke(command.second, program.at(command.first)); diff --git a/src/Platform/modules/GridSearch.cc b/src/Platform/modules/GridSearch.cc index daff8b1..2ef0ad1 100644 --- a/src/Platform/modules/GridSearch.cc +++ b/src/Platform/modules/GridSearch.cc @@ -438,15 +438,4 @@ namespace platform { }; file << output.dump(4); } - void GridSearch::exportResults(json& results) - { - std::ofstream file(Paths::grid_export(config.model)); - auto output = json(); - for (const auto& item : results["results"].items()) { - auto key = item.key(); - auto value = item.value(); - output[key] = value["hyperparameters"]; - } - file << output.dump(4); - } } /* namespace platform */ \ No newline at end of file diff --git a/src/Platform/modules/GridSearch.h b/src/Platform/modules/GridSearch.h index e665519..ec1b3cb 100644 --- a/src/Platform/modules/GridSearch.h +++ b/src/Platform/modules/GridSearch.h @@ -47,7 +47,6 @@ namespace platform { void go(struct ConfigMPI& config_mpi); ~GridSearch() = default; json loadResults(); - void exportResults(json& results); static inline std::string NO_CONTINUE() { return "NO_CONTINUE"; } private: void save(json& results); diff --git a/src/Platform/modules/HyperParameters.cc b/src/Platform/modules/HyperParameters.cc index 92d56d7..530ac03 100644 --- a/src/Platform/modules/HyperParameters.cc +++ b/src/Platform/modules/HyperParameters.cc @@ -27,7 +27,8 @@ namespace platform { throw std::runtime_error("File " + hyperparameters_file + " not found"); } // Check if file is a json - json input_hyperparameters = json::parse(file); + json file_hyperparameters = json::parse(file); + auto input_hyperparameters = file_hyperparameters["results"]; // Check if hyperparameters are valid for (const auto& dataset : datasets) { if (!input_hyperparameters.contains(dataset)) { diff --git a/src/Platform/modules/Paths.h b/src/Platform/modules/Paths.h index 8d19040..6fd61cf 100644 --- a/src/Platform/modules/Paths.h +++ b/src/Platform/modules/Paths.h @@ -34,10 +34,6 @@ namespace platform { { return grid() + "grid_" + model + "_output.json"; } - static std::string grid_export(const std::string& model) - { - return grid() + "grid_" + model + ".json"; - } }; } #endif \ No newline at end of file