Fix dataset name order in grid experiment

This commit is contained in:
2025-01-17 16:58:39 +01:00
parent d0e65348e0
commit eb430a84c4
3 changed files with 12 additions and 10 deletions

View File

@@ -101,14 +101,14 @@ void add_experiment_args(argparse::ArgumentParser& program)
auto seed_values = env.getSeeds();
program.add_argument("--seeds").nargs(1, 10).help("Random seeds. Set to -1 to have pseudo random").scan<'i', int>().default_value(seed_values);
}
void add_compute_args(argparse::ArgumentParser& program)
void add_search_args(argparse::ArgumentParser& program)
{
auto env = platform::DotEnv();
program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
program.add_argument("--continue").help("Continue computing from that dataset").default_value(platform::GridSearch::NO_CONTINUE());
program.add_argument("--only").help("Used with continue to compute that dataset only").default_value(false).implicit_value(true);
program.add_argument("--only").help("Used with continue to search with that dataset only").default_value(false).implicit_value(true);
program.add_argument("--exclude").default_value("[]").help("Datasets to exclude in json format, e.g. [\"dataset1\", \"dataset2\"]");
auto valid_choices = env.valid_tokens("smooth_strat");
auto& smooth_arg = program.add_argument("--smooth-strat").help("Smooth strategy used in Bayes Network node initialization. Valid values: " + env.valid_values("smooth_strat")).default_value(env.get("smooth_strat"));
@@ -288,7 +288,7 @@ void search(argparse::ArgumentParser& program)
MPI_Comm_rank(MPI_COMM_WORLD, &mpi_config.rank);
MPI_Comm_size(MPI_COMM_WORLD, &mpi_config.n_procs);
if (mpi_config.n_procs < 2) {
throw std::runtime_error("Cannot use --compute with less than 2 mpi processes, try mpirun -np 2 ...");
throw std::runtime_error("Cannot use --search with less than 2 mpi processes, try mpirun -np 2 ...");
}
grid_search.go(mpi_config);
if (mpi_config.rank == mpi_config.manager) {
@@ -314,7 +314,7 @@ void experiment(argparse::ArgumentParser& program)
MPI_Comm_rank(MPI_COMM_WORLD, &mpi_config.rank);
MPI_Comm_size(MPI_COMM_WORLD, &mpi_config.n_procs);
if (mpi_config.n_procs < 2) {
throw std::runtime_error("Cannot use --compute with less than 2 mpi processes, try mpirun -np 2 ...");
throw std::runtime_error("Cannot use --experiment with less than 2 mpi processes, try mpirun -np 2 ...");
}
grid_experiment.go(mpi_config);
if (mpi_config.rank == mpi_config.manager) {
@@ -322,8 +322,8 @@ void experiment(argparse::ArgumentParser& program)
std::cout << "* Report of the computed hyperparameters" << std::endl;
auto duration = timer.getDuration();
experiment.setDuration(duration);
// experiment.report(grid_experiment.numFiles() == 1);
experiment.saveResult();
experiment.report(grid_experiment.numFiles() == 1);
std::cout << "Process took " << duration << std::endl;
}
MPI_Finalize();
@@ -344,11 +344,11 @@ int main(int argc, char** argv)
assignModel(report_command);
report_command.add_description("Report the computed hyperparameters of a model.");
// grid compute subparser
// grid search subparser
argparse::ArgumentParser search_command("search");
search_command.add_description("Search using mpi the hyperparameters of a model.");
assignModel(search_command);
add_compute_args(search_command);
add_search_args(search_command);
// grid experiment subparser
argparse::ArgumentParser experiment_command("experiment");
@@ -376,7 +376,7 @@ int main(int argc, char** argv)
}
}
if (!found) {
throw std::runtime_error("You must specify one of the following commands: dump, report, compute\n");
throw std::runtime_error("You must specify one of the following commands: dump, experiment, report, search \n");
}
}
catch (const exception& err) {

View File

@@ -95,7 +95,6 @@ namespace platform {
}
}
}
platform::HyperParameters test_hyperparams;
if (hyperparameters_file != "") {
test_hyperparams = platform::HyperParameters(datasets.getNames(), hyperparameters_file, hyper_best);
@@ -109,6 +108,7 @@ namespace platform {
this->config.smooth_strategy = smooth_strat;
this->config.n_folds = n_folds;
this->config.seeds = seeds;
this->config.quiet = false;
auto env = platform::DotEnv();
experiment.setTitle(title).setLanguage("c++").setLanguageVersion("gcc 14.1.1");
experiment.setDiscretizationAlgorithm(discretize_algo).setSmoothSrategy(smooth_strat);
@@ -138,6 +138,8 @@ namespace platform {
void GridExperiment::compile_results(json& results, json& all_results, std::string& model)
{
auto datasets = Datasets(false, Paths::datasets());
nlohmann::json temp = all_results; // To restore the order of the data by dataset name
all_results = temp;
for (const auto& result_item : all_results.items()) {
// each result has the results of all the outer folds, as if each one were a different task
auto dataset_name = result_item.key();

View File

@@ -224,7 +224,7 @@ namespace platform {
std::string ReportConsole::buildClassificationReport(json& result, std::string color)
{
std::stringstream oss;
if (result.find("confusion_matrices") == result.end())
if (result.find("confusion_matrices") == result.end() || result["confusion_matrices"].size() == 0)
return "";
bool second_header = false;
int lines_header = 0;