Continue grid Experiment

This commit is contained in:
2025-01-14 22:04:23 +01:00
parent 386faf960e
commit 9a9a9fb17a
7 changed files with 226 additions and 352 deletions

View File

@@ -36,22 +36,22 @@ void add_experiment_args(argparse::ArgumentParser& program)
{
auto env = platform::DotEnv();
auto datasets = platform::Datasets(false, platform::Paths::datasets());
auto& group = program.add_mutually_exclusive_group(true);
group.add_argument("-d", "--dataset")
.help("Dataset file name: " + datasets.toString())
.default_value("all")
.action([](const std::string& value) {
auto datasets = platform::Datasets(false, platform::Paths::datasets());
static std::vector<std::string> choices_datasets(datasets.getNames());
choices_datasets.push_back("all");
if (find(choices_datasets.begin(), choices_datasets.end(), value) != choices_datasets.end()) {
return value;
}
throw std::runtime_error("Dataset must be one of: " + datasets.toString());
}
);
group.add_argument("--datasets").nargs(1, 50).help("Datasets file names 1..50 separated by spaces").default_value(std::vector<std::string>());
group.add_argument("--datasets-file").default_value("").help("Datasets file name. Mutually exclusive with dataset. This file should contain a list of datasets to test.");
// auto& group = program.add_mutually_exclusive_group(true);
// group.add_argument("-d", "--dataset")
// .help("Dataset file name: " + datasets.toString())
// .default_value("all")
// .action([](const std::string& value) {
// auto datasets = platform::Datasets(false, platform::Paths::datasets());
// static std::vector<std::string> choices_datasets(datasets.getNames());
// choices_datasets.push_back("all");
// if (find(choices_datasets.begin(), choices_datasets.end(), value) != choices_datasets.end()) {
// return value;
// }
// throw std::runtime_error("Dataset must be one of: " + datasets.toString());
// }
// );
// group.add_argument("--datasets").nargs(1, 50).help("Datasets file names 1..50 separated by spaces").default_value(std::vector<std::string>());
// group.add_argument("--datasets-file").default_value("").help("Datasets file name. Mutually exclusive with dataset. This file should contain a list of datasets to test.");
program.add_argument("--hyperparameters").default_value("{}").help("Hyperparameters passed to the model in Experiment");
program.add_argument("--hyper-file").default_value("").help("Hyperparameters file name." \
"Mutually exclusive with hyperparameters. This file should contain hyperparameters for each dataset in json format.");
@@ -261,7 +261,7 @@ void report(argparse::ArgumentParser& program)
list_results(results, config.model);
}
}
void compute(argparse::ArgumentParser& program)
void search(argparse::ArgumentParser& program)
{
struct platform::ConfigGrid config;
config.model = program.get<std::string>("model");
@@ -298,6 +298,7 @@ void compute(argparse::ArgumentParser& program)
grid_search.go(mpi_config);
if (mpi_config.rank == mpi_config.manager) {
auto results = grid_search.loadResults();
std::cout << Colors::RESET() << "* Report of the computed hyperparameters" << std::endl;
list_results(results, config.model);
std::cout << "Process took " << timer.getDurationString() << std::endl;
}
@@ -331,7 +332,9 @@ void experiment(argparse::ArgumentParser& program)
}
grid_experiment.go(mpi_config);
if (mpi_config.rank == mpi_config.manager) {
// auto results = grid_experiment.loadResults();
auto results = grid_experiment.getResults();
std::cout << "****** RESULTS ********" << std::endl;
std::cout << results.dump(4) << std::endl;
// list_results(results, config.model);
std::cout << "Process took " << timer.getDurationString() << std::endl;
}
@@ -354,10 +357,10 @@ int main(int argc, char** argv)
report_command.add_description("Report the computed hyperparameters of a model.");
// grid compute subparser
argparse::ArgumentParser compute_command("compute");
compute_command.add_description("Compute using mpi the hyperparameters of a model.");
assignModel(compute_command);
add_compute_args(compute_command);
argparse::ArgumentParser search_command("search");
search_command.add_description("Search using mpi the hyperparameters of a model.");
assignModel(search_command);
add_compute_args(search_command);
// grid experiment subparser
argparse::ArgumentParser experiment_command("experiment");
@@ -367,7 +370,7 @@ int main(int argc, char** argv)
program.add_subparser(dump_command);
program.add_subparser(report_command);
program.add_subparser(compute_command);
program.add_subparser(search_command);
program.add_subparser(experiment_command);
//
@@ -376,7 +379,7 @@ int main(int argc, char** argv)
try {
program.parse_args(argc, argv);
bool found = false;
map<std::string, void(*)(argparse::ArgumentParser&)> commands = { {"dump", &dump}, {"report", &report}, {"compute", &compute}, { "experiment",&experiment } };
map<std::string, void(*)(argparse::ArgumentParser&)> commands = { {"dump", &dump}, {"report", &report}, {"search", &search}, { "experiment",&experiment } };
for (const auto& command : commands) {
if (program.is_subcommand_used(command.first)) {
std::invoke(command.second, program.at<argparse::ArgumentParser>(command.first));