Add b_main support to grid_output files

This commit is contained in:
2024-01-15 11:53:34 +01:00
parent ecce7955f8
commit 2b20d0315c
5 changed files with 3 additions and 41 deletions

View File

@@ -177,22 +177,6 @@ void report(argparse::ArgumentParser& program)
list_results(results, config.model);
}
}
void exportResults(argparse::ArgumentParser& program)
{
// Generate a grid_<model_name>.json file with the results of the grid search
// this file can be used by b_main to run the model with the best hyperparameters
struct platform::ConfigGrid config;
config.model = program.get<std::string>("model");
auto grid_search = platform::GridSearch(config);
auto results = grid_search.loadResults();
auto output = json::array();
if (results.empty()) {
std::cout << "** No results found" << std::endl;
} else {
grid_search.exportResults(results);
std::cout << "Exported results to " << platform::Paths::grid_export(config.model) << std::endl;
}
}
void compute(argparse::ArgumentParser& program)
{
struct platform::ConfigGrid config;
@@ -256,15 +240,9 @@ int main(int argc, char** argv)
assignModel(compute_command);
add_compute_args(compute_command);
// grid export subparser
argparse::ArgumentParser export_command("export");
assignModel(export_command);
export_command.add_description("Export the computed hyperparameters to a file readable by b_main.");
program.add_subparser(dump_command);
program.add_subparser(report_command);
program.add_subparser(compute_command);
program.add_subparser(export_command);
//
// Process options
@@ -272,8 +250,7 @@ int main(int argc, char** argv)
try {
program.parse_args(argc, argv);
bool found = false;
map<std::string, void(*)(argparse::ArgumentParser&)> commands =
{ {"dump", &dump}, {"report", &report}, {"export", &exportResults}, {"compute", &compute} };
map<std::string, void(*)(argparse::ArgumentParser&)> commands = { {"dump", &dump}, {"report", &report}, {"compute", &compute} };
for (const auto& command : commands) {
if (program.is_subcommand_used(command.first)) {
std::invoke(command.second, program.at<argparse::ArgumentParser>(command.first));

View File

@@ -438,15 +438,4 @@ namespace platform {
};
file << output.dump(4);
}
void GridSearch::exportResults(json& results)
{
std::ofstream file(Paths::grid_export(config.model));
auto output = json();
for (const auto& item : results["results"].items()) {
auto key = item.key();
auto value = item.value();
output[key] = value["hyperparameters"];
}
file << output.dump(4);
}
} /* namespace platform */

View File

@@ -47,7 +47,6 @@ namespace platform {
void go(struct ConfigMPI& config_mpi);
~GridSearch() = default;
json loadResults();
void exportResults(json& results);
static inline std::string NO_CONTINUE() { return "NO_CONTINUE"; }
private:
void save(json& results);

View File

@@ -27,7 +27,8 @@ namespace platform {
throw std::runtime_error("File " + hyperparameters_file + " not found");
}
// Check if file is a json
json input_hyperparameters = json::parse(file);
json file_hyperparameters = json::parse(file);
auto input_hyperparameters = file_hyperparameters["results"];
// Check if hyperparameters are valid
for (const auto& dataset : datasets) {
if (!input_hyperparameters.contains(dataset)) {

View File

@@ -34,10 +34,6 @@ namespace platform {
{
return grid() + "grid_" + model + "_output.json";
}
static std::string grid_export(const std::string& model)
{
return grid() + "grid_" + model + ".json";
}
};
}
#endif