diff --git a/src/commands/b_grid.cpp b/src/commands/b_grid.cpp index 1a0281a..94cdcee 100644 --- a/src/commands/b_grid.cpp +++ b/src/commands/b_grid.cpp @@ -101,14 +101,14 @@ void add_experiment_args(argparse::ArgumentParser& program) auto seed_values = env.getSeeds(); program.add_argument("--seeds").nargs(1, 10).help("Random seeds. Set to -1 to have pseudo random").scan<'i', int>().default_value(seed_values); } -void add_compute_args(argparse::ArgumentParser& program) +void add_search_args(argparse::ArgumentParser& program) { auto env = platform::DotEnv(); program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true); program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true); program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true); program.add_argument("--continue").help("Continue computing from that dataset").default_value(platform::GridSearch::NO_CONTINUE()); - program.add_argument("--only").help("Used with continue to compute that dataset only").default_value(false).implicit_value(true); + program.add_argument("--only").help("Used with continue to search with that dataset only").default_value(false).implicit_value(true); program.add_argument("--exclude").default_value("[]").help("Datasets to exclude in json format, e.g. [\"dataset1\", \"dataset2\"]"); auto valid_choices = env.valid_tokens("smooth_strat"); auto& smooth_arg = program.add_argument("--smooth-strat").help("Smooth strategy used in Bayes Network node initialization. Valid values: " + env.valid_values("smooth_strat")).default_value(env.get("smooth_strat")); @@ -288,7 +288,7 @@ void search(argparse::ArgumentParser& program) MPI_Comm_rank(MPI_COMM_WORLD, &mpi_config.rank); MPI_Comm_size(MPI_COMM_WORLD, &mpi_config.n_procs); if (mpi_config.n_procs < 2) { - throw std::runtime_error("Cannot use --compute with less than 2 mpi processes, try mpirun -np 2 ..."); + throw std::runtime_error("Cannot use --search with less than 2 mpi processes, try mpirun -np 2 ..."); } grid_search.go(mpi_config); if (mpi_config.rank == mpi_config.manager) { @@ -314,7 +314,7 @@ void experiment(argparse::ArgumentParser& program) MPI_Comm_rank(MPI_COMM_WORLD, &mpi_config.rank); MPI_Comm_size(MPI_COMM_WORLD, &mpi_config.n_procs); if (mpi_config.n_procs < 2) { - throw std::runtime_error("Cannot use --compute with less than 2 mpi processes, try mpirun -np 2 ..."); + throw std::runtime_error("Cannot use --experiment with less than 2 mpi processes, try mpirun -np 2 ..."); } grid_experiment.go(mpi_config); if (mpi_config.rank == mpi_config.manager) { @@ -322,8 +322,8 @@ void experiment(argparse::ArgumentParser& program) std::cout << "* Report of the computed hyperparameters" << std::endl; auto duration = timer.getDuration(); experiment.setDuration(duration); - // experiment.report(grid_experiment.numFiles() == 1); experiment.saveResult(); + experiment.report(grid_experiment.numFiles() == 1); std::cout << "Process took " << duration << std::endl; } MPI_Finalize(); @@ -344,11 +344,11 @@ int main(int argc, char** argv) assignModel(report_command); report_command.add_description("Report the computed hyperparameters of a model."); - // grid compute subparser + // grid search subparser argparse::ArgumentParser search_command("search"); search_command.add_description("Search using mpi the hyperparameters of a model."); assignModel(search_command); - add_compute_args(search_command); + add_search_args(search_command); // grid experiment subparser argparse::ArgumentParser experiment_command("experiment"); @@ -376,7 +376,7 @@ int main(int argc, char** argv) } } if (!found) { - throw std::runtime_error("You must specify one of the following commands: dump, report, compute\n"); + throw std::runtime_error("You must specify one of the following commands: dump, experiment, report, search \n"); } } catch (const exception& err) { diff --git a/src/grid/GridExperiment.cpp b/src/grid/GridExperiment.cpp index c273ece..67db3b7 100644 --- a/src/grid/GridExperiment.cpp +++ b/src/grid/GridExperiment.cpp @@ -95,7 +95,6 @@ namespace platform { } } } - platform::HyperParameters test_hyperparams; if (hyperparameters_file != "") { test_hyperparams = platform::HyperParameters(datasets.getNames(), hyperparameters_file, hyper_best); @@ -109,6 +108,7 @@ namespace platform { this->config.smooth_strategy = smooth_strat; this->config.n_folds = n_folds; this->config.seeds = seeds; + this->config.quiet = false; auto env = platform::DotEnv(); experiment.setTitle(title).setLanguage("c++").setLanguageVersion("gcc 14.1.1"); experiment.setDiscretizationAlgorithm(discretize_algo).setSmoothSrategy(smooth_strat); @@ -138,6 +138,8 @@ namespace platform { void GridExperiment::compile_results(json& results, json& all_results, std::string& model) { auto datasets = Datasets(false, Paths::datasets()); + nlohmann::json temp = all_results; // To restore the order of the data by dataset name + all_results = temp; for (const auto& result_item : all_results.items()) { // each result has the results of all the outer folds as each one were a different task auto dataset_name = result_item.key(); diff --git a/src/reports/ReportConsole.cpp b/src/reports/ReportConsole.cpp index 9317b53..ef0eb97 100644 --- a/src/reports/ReportConsole.cpp +++ b/src/reports/ReportConsole.cpp @@ -224,7 +224,7 @@ namespace platform { std::string ReportConsole::buildClassificationReport(json& result, std::string color) { std::stringstream oss; - if (result.find("confusion_matrices") == result.end()) + if (result.find("confusion_matrices") == result.end() || result["confusion_matrices"].size() == 0) return ""; bool second_header = false; int lines_header = 0;