From b3b3d9f1b903e79e396d51ab50b0d1ca3b57a174 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Sun, 10 Mar 2024 12:16:02 +0100 Subject: [PATCH] Add command results to b_list Rename tostring -> toString in models Add datasets names to b_main command help - validation --- sample/sample.cc | 4 +-- src/CMakeLists.txt | 4 +-- src/common/Datasets.cc | 10 ++++++ src/common/Datasets.h | 1 + src/grid/b_grid.cc | 6 ++-- src/list/b_list.cc | 76 ++++++++++++++++++++++++++++++++++++------ src/main/Models.cc | 8 +++-- src/main/Models.h | 2 +- src/main/b_main.cc | 17 ++++++++-- 9 files changed, 104 insertions(+), 24 deletions(-) diff --git a/sample/sample.cc b/sample/sample.cc index 7074846..491f82e 100644 --- a/sample/sample.cc +++ b/sample/sample.cc @@ -85,13 +85,13 @@ int main(int argc, char** argv) .default_value(std::string{ PATH } ); program.add_argument("-m", "--model") - .help("Model to use " + platform::Models::instance()->tostring()) + .help("Model to use " + platform::Models::instance()->toString()) .action([](const std::string& value) { static const std::vector choices = platform::Models::instance()->getNames(); if (find(choices.begin(), choices.end(), value) != choices.end()) { return value; } - throw runtime_error("Model must be one of " + platform::Models::instance()->tostring()); + throw runtime_error("Model must be one of " + platform::Models::instance()->toString()); } ); program.add_argument("--discretize").help("Discretize input dataset").default_value(false).implicit_value(true); diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index da0686f..5da450b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -20,8 +20,8 @@ set(best_sources b_best.cc BestResults.cc Statistics.cc BestResultsExcel.cc) list(TRANSFORM best_sources PREPEND best/) add_executable( b_best ${best_sources} main/Result.cc - reports/ReportExcel.cc reports/ReportBase.cc reports/ExcelFile.cc common/Datasets.cc common/Dataset.cc) -target_link_libraries(b_best Boost::boost "${TORCH_LIBRARIES}" "${XLSXWRITER_LIB}" ArffFiles mdlp) + reports/ReportExcel.cc reports/ReportBase.cc reports/ExcelFile.cc common/Datasets.cc common/Dataset.cc main/Models.cc) +target_link_libraries(b_best Boost::boost "${PyClassifiers}" "${BayesNet}" ArffFiles mdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::python Boost::numpy "${XLSXWRITER_LIB}") # b_grid set(grid_sources b_grid.cc GridSearch.cc GridData.cc) diff --git a/src/common/Datasets.cc b/src/common/Datasets.cc index bfb419f..4e68d8a 100644 --- a/src/common/Datasets.cc +++ b/src/common/Datasets.cc @@ -126,4 +126,14 @@ namespace platform { { return datasets.find(name) != datasets.end(); } + std::string Datasets::toString() const + { + std::string result; + std::string sep = ""; + for (const auto& d : datasets) { + result += sep + d.first; + sep = ", "; + } + return "{" + result + "}"; + } } \ No newline at end of file diff --git a/src/common/Datasets.h b/src/common/Datasets.h index 4ead616..0f7f0cf 100644 --- a/src/common/Datasets.h +++ b/src/common/Datasets.h @@ -24,6 +24,7 @@ namespace platform { std::pair getTensors(const std::string& name); bool isDataset(const std::string& name) const; void loadDataset(const std::string& name) const; + std::string toString() const; }; }; diff --git a/src/grid/b_grid.cc b/src/grid/b_grid.cc index c621460..fb37c47 100644 --- a/src/grid/b_grid.cc +++ b/src/grid/b_grid.cc @@ -20,14 +20,14 @@ void assignModel(argparse::ArgumentParser& parser) { auto models = platform::Models::instance(); parser.add_argument("-m", "--model") - .help("Model to use " + models->tostring()) + .help("Model to use " + models->toString()) .required() .action([models](const std::string& value) { static const std::vector choices = models->getNames(); if (find(choices.begin(), choices.end(), value) != choices.end()) { return value; } - throw std::runtime_error("Model must be one of " + models->tostring()); + throw std::runtime_error("Model must be one of " + models->toString()); } ); } @@ -259,7 +259,7 @@ int main(int argc, char** argv) } } if (!found) { - throw std::runtime_error("You must specify one of the following commands: dump, report, compute, export\n"); + throw std::runtime_error("You must specify one of the following commands: dump, report, compute\n"); } } catch (const exception& err) { diff --git a/src/list/b_list.cc b/src/list/b_list.cc index 24240d7..1cbe4a4 100644 --- a/src/list/b_list.cc +++ b/src/list/b_list.cc @@ -1,5 +1,6 @@ #include #include +#include #include #include #include "common/Paths.h" @@ -28,15 +29,9 @@ std::string outputBalance(const std::string& balance) return temp; } -int main(int argc, char** argv) +void list_datasets(argparse::ArgumentParser& program) { auto datasets = platform::Datasets(false, platform::Paths::datasets()); - argparse::ArgumentParser program("b_list", { platform_project_version.begin(), platform_project_version.end() }); - program.add_argument("--excel") - .help("Output in Excel format") - .default_value(false) - .implicit_value(true); - program.parse_args(argc, argv); auto excel = program.get("--excel"); locale mylocale(std::cout.getloc(), new separated); locale::global(mylocale); @@ -70,11 +65,72 @@ int main(int argc, char** argv) data[dataset]["classes"] = datasets.getNClasses(dataset); data[dataset]["balance"] = oss.str(); } - std::cout << Colors::RESET() << std::endl; if (excel) { auto report = platform::DatasetsExcel(); report.report(data); - std::cout << "Output saved in " << report.getFileName() << std::endl; + std::cout << std::endl << Colors::GREEN() << "Output saved in " << report.getFileName() << std::endl; } - return 0; } + +void list_results(argparse::ArgumentParser& program) +{ + std::cout << "Results" << std::endl; +} + +int main(int argc, char** argv) +{ + argparse::ArgumentParser program("b_list", { platform_project_version.begin(), platform_project_version.end() }); + // + // datasets subparser + // + argparse::ArgumentParser datasets_command("datasets"); + datasets_command.add_description("List datasets available in the platform."); + datasets_command.add_argument("--excel") + .help("Output in Excel format") + .default_value(false) + .implicit_value(true); + // + // results subparser + // + argparse::ArgumentParser results_command("results"); + results_command.add_description("List the results of a given dataset."); + auto datasets = platform::Datasets(false, platform::Paths::datasets()); + results_command.add_argument("-d", "--dataset") + .help("Dataset to use " + datasets.toString()) + .required() + .action([](const std::string& value) { + auto datasets = platform::Datasets(false, platform::Paths::datasets()); + static const std::vector choices = datasets.getNames(); + if (find(choices.begin(), choices.end(), value) != choices.end()) { + return value; + } + throw std::runtime_error("Dataset must be one of " + datasets.toString()); + } + ); + // Add subparsers + program.add_subparser(datasets_command); + program.add_subparser(results_command); + // Parse command line and execute + try { + program.parse_args(argc, argv); + bool found = false; + map commands = { {"datasets", &list_datasets}, {"results", &list_results} }; + for (const auto& command : commands) { + if (program.is_subcommand_used(command.first)) { + std::invoke(command.second, program.at(command.first)); + found = true; + break; + } + } + if (!found) { + throw std::runtime_error("You must specify one of the following commands: datasets, results\n"); + } + } + catch (const exception& err) { + cerr << err.what() << std::endl; + cerr << program; + exit(1); + } + std::cout << Colors::RESET() << std::endl; + return 0; +} \ No newline at end of file diff --git a/src/main/Models.cc b/src/main/Models.cc index 10929e4..fc338f0 100644 --- a/src/main/Models.cc +++ b/src/main/Models.cc @@ -36,13 +36,15 @@ namespace platform { [](const pair>& pair) { return pair.first; }); return names; } - std::string Models::tostring() + std::string Models::toString() { std::string result = ""; + std::string sep = ""; for (const auto& pair : functionRegistry) { - result += pair.first + ", "; + result += sep + pair.first; + sep = ", "; } - return "{" + result.substr(0, result.size() - 2) + "}"; + return "{" + result + "}"; } Registrar::Registrar(const std::string& name, function classFactoryFunction) { diff --git a/src/main/Models.h b/src/main/Models.h index f303854..a0f526b 100644 --- a/src/main/Models.h +++ b/src/main/Models.h @@ -31,7 +31,7 @@ namespace platform { void registerFactoryFunction(const std::string& name, function classFactoryFunction); std::vector getNames(); - std::string tostring(); + std::string toString(); }; class Registrar { diff --git a/src/main/b_main.cc b/src/main/b_main.cc index 1999fff..33a6f25 100644 --- a/src/main/b_main.cc +++ b/src/main/b_main.cc @@ -15,18 +15,29 @@ using json = nlohmann::json; void manageArguments(argparse::ArgumentParser& program) { auto env = platform::DotEnv(); - program.add_argument("-d", "--dataset").default_value("").help("Dataset file name"); + auto datasets = platform::Datasets(false, platform::Paths::datasets()); + program.add_argument("-d", "--dataset") + .help("Dataset file name: " + datasets.toString()) + .action([](const std::string& value) { + auto datasets = platform::Datasets(false, platform::Paths::datasets()); + static const std::vector choices_datasets(datasets.getNames()); + if (find(choices_datasets.begin(), choices_datasets.end(), value) != choices_datasets.end()) { + return value; + } + throw std::runtime_error("Dataset must be one of: " + datasets.toString()); + } + ); program.add_argument("--hyperparameters").default_value("{}").help("Hyperparameters passed to the model in Experiment"); program.add_argument("--hyper-file").default_value("").help("Hyperparameters file name." \ "Mutually exclusive with hyperparameters. This file should contain hyperparameters for each dataset in json format."); program.add_argument("-m", "--model") - .help("Model to use " + platform::Models::instance()->tostring()) + .help("Model to use: " + platform::Models::instance()->toString()) .action([](const std::string& value) { static const std::vector choices = platform::Models::instance()->getNames(); if (find(choices.begin(), choices.end(), value) != choices.end()) { return value; } - throw std::runtime_error("Model must be one of " + platform::Models::instance()->tostring()); + throw std::runtime_error("Model must be one of " + platform::Models::instance()->toString()); } ); program.add_argument("--title").default_value("").help("Experiment title");