Add b_main support to grid_output files
This commit is contained in:
@@ -177,22 +177,6 @@ void report(argparse::ArgumentParser& program)
|
||||
list_results(results, config.model);
|
||||
}
|
||||
}
|
||||
void exportResults(argparse::ArgumentParser& program)
|
||||
{
|
||||
// Generate a grid_<model_name>.json file with the results of the grid search
|
||||
// this file can be used by b_main to run the model with the best hyperparameters
|
||||
struct platform::ConfigGrid config;
|
||||
config.model = program.get<std::string>("model");
|
||||
auto grid_search = platform::GridSearch(config);
|
||||
auto results = grid_search.loadResults();
|
||||
auto output = json::array();
|
||||
if (results.empty()) {
|
||||
std::cout << "** No results found" << std::endl;
|
||||
} else {
|
||||
grid_search.exportResults(results);
|
||||
std::cout << "Exported results to " << platform::Paths::grid_export(config.model) << std::endl;
|
||||
}
|
||||
}
|
||||
void compute(argparse::ArgumentParser& program)
|
||||
{
|
||||
struct platform::ConfigGrid config;
|
||||
@@ -256,15 +240,9 @@ int main(int argc, char** argv)
|
||||
assignModel(compute_command);
|
||||
add_compute_args(compute_command);
|
||||
|
||||
// grid export subparser
|
||||
argparse::ArgumentParser export_command("export");
|
||||
assignModel(export_command);
|
||||
export_command.add_description("Export the computed hyperparameters to a file readable by b_main.");
|
||||
|
||||
program.add_subparser(dump_command);
|
||||
program.add_subparser(report_command);
|
||||
program.add_subparser(compute_command);
|
||||
program.add_subparser(export_command);
|
||||
|
||||
//
|
||||
// Process options
|
||||
@@ -272,8 +250,7 @@ int main(int argc, char** argv)
|
||||
try {
|
||||
program.parse_args(argc, argv);
|
||||
bool found = false;
|
||||
map<std::string, void(*)(argparse::ArgumentParser&)> commands =
|
||||
{ {"dump", &dump}, {"report", &report}, {"export", &exportResults}, {"compute", &compute} };
|
||||
map<std::string, void(*)(argparse::ArgumentParser&)> commands = { {"dump", &dump}, {"report", &report}, {"compute", &compute} };
|
||||
for (const auto& command : commands) {
|
||||
if (program.is_subcommand_used(command.first)) {
|
||||
std::invoke(command.second, program.at<argparse::ArgumentParser>(command.first));
|
||||
|
@@ -438,15 +438,4 @@ namespace platform {
|
||||
};
|
||||
file << output.dump(4);
|
||||
}
|
||||
void GridSearch::exportResults(json& results)
|
||||
{
|
||||
std::ofstream file(Paths::grid_export(config.model));
|
||||
auto output = json();
|
||||
for (const auto& item : results["results"].items()) {
|
||||
auto key = item.key();
|
||||
auto value = item.value();
|
||||
output[key] = value["hyperparameters"];
|
||||
}
|
||||
file << output.dump(4);
|
||||
}
|
||||
} /* namespace platform */
|
@@ -47,7 +47,6 @@ namespace platform {
|
||||
void go(struct ConfigMPI& config_mpi);
|
||||
~GridSearch() = default;
|
||||
json loadResults();
|
||||
void exportResults(json& results);
|
||||
static inline std::string NO_CONTINUE() { return "NO_CONTINUE"; }
|
||||
private:
|
||||
void save(json& results);
|
||||
|
@@ -27,7 +27,8 @@ namespace platform {
|
||||
throw std::runtime_error("File " + hyperparameters_file + " not found");
|
||||
}
|
||||
// Check if file is a json
|
||||
json input_hyperparameters = json::parse(file);
|
||||
json file_hyperparameters = json::parse(file);
|
||||
auto input_hyperparameters = file_hyperparameters["results"];
|
||||
// Check if hyperparameters are valid
|
||||
for (const auto& dataset : datasets) {
|
||||
if (!input_hyperparameters.contains(dataset)) {
|
||||
|
@@ -34,10 +34,6 @@ namespace platform {
|
||||
{
|
||||
return grid() + "grid_" + model + "_output.json";
|
||||
}
|
||||
static std::string grid_export(const std::string& model)
|
||||
{
|
||||
return grid() + "grid_" + model + ".json";
|
||||
}
|
||||
};
|
||||
}
|
||||
#endif
|
Reference in New Issue
Block a user