Add command results to b_list

Rename tostring -> toString in models
Add datasets names to b_main command help - validation
This commit is contained in:
2024-03-10 12:16:02 +01:00
parent 4c847fc3f6
commit b3b3d9f1b9
9 changed files with 104 additions and 24 deletions

View File

@@ -85,13 +85,13 @@ int main(int argc, char** argv)
.default_value(std::string{ PATH } .default_value(std::string{ PATH }
); );
program.add_argument("-m", "--model") program.add_argument("-m", "--model")
.help("Model to use " + platform::Models::instance()->tostring()) .help("Model to use " + platform::Models::instance()->toString())
.action([](const std::string& value) { .action([](const std::string& value) {
static const std::vector<std::string> choices = platform::Models::instance()->getNames(); static const std::vector<std::string> choices = platform::Models::instance()->getNames();
if (find(choices.begin(), choices.end(), value) != choices.end()) { if (find(choices.begin(), choices.end(), value) != choices.end()) {
return value; return value;
} }
throw runtime_error("Model must be one of " + platform::Models::instance()->tostring()); throw runtime_error("Model must be one of " + platform::Models::instance()->toString());
} }
); );
program.add_argument("--discretize").help("Discretize input dataset").default_value(false).implicit_value(true); program.add_argument("--discretize").help("Discretize input dataset").default_value(false).implicit_value(true);

View File

@@ -20,8 +20,8 @@ set(best_sources b_best.cc BestResults.cc Statistics.cc BestResultsExcel.cc)
list(TRANSFORM best_sources PREPEND best/) list(TRANSFORM best_sources PREPEND best/)
add_executable( add_executable(
b_best ${best_sources} main/Result.cc b_best ${best_sources} main/Result.cc
reports/ReportExcel.cc reports/ReportBase.cc reports/ExcelFile.cc common/Datasets.cc common/Dataset.cc) reports/ReportExcel.cc reports/ReportBase.cc reports/ExcelFile.cc common/Datasets.cc common/Dataset.cc main/Models.cc)
target_link_libraries(b_best Boost::boost "${TORCH_LIBRARIES}" "${XLSXWRITER_LIB}" ArffFiles mdlp) target_link_libraries(b_best Boost::boost "${PyClassifiers}" "${BayesNet}" ArffFiles mdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::python Boost::numpy "${XLSXWRITER_LIB}")
# b_grid # b_grid
set(grid_sources b_grid.cc GridSearch.cc GridData.cc) set(grid_sources b_grid.cc GridSearch.cc GridData.cc)

View File

@@ -126,4 +126,14 @@ namespace platform {
{ {
return datasets.find(name) != datasets.end(); return datasets.find(name) != datasets.end();
} }
std::string Datasets::toString() const
{
std::string result;
std::string sep = "";
for (const auto& d : datasets) {
result += sep + d.first;
sep = ", ";
}
return "{" + result + "}";
}
} }

View File

@@ -24,6 +24,7 @@ namespace platform {
std::pair<torch::Tensor&, torch::Tensor&> getTensors(const std::string& name); std::pair<torch::Tensor&, torch::Tensor&> getTensors(const std::string& name);
bool isDataset(const std::string& name) const; bool isDataset(const std::string& name) const;
void loadDataset(const std::string& name) const; void loadDataset(const std::string& name) const;
std::string toString() const;
}; };
}; };

View File

@@ -20,14 +20,14 @@ void assignModel(argparse::ArgumentParser& parser)
{ {
auto models = platform::Models::instance(); auto models = platform::Models::instance();
parser.add_argument("-m", "--model") parser.add_argument("-m", "--model")
.help("Model to use " + models->tostring()) .help("Model to use " + models->toString())
.required() .required()
.action([models](const std::string& value) { .action([models](const std::string& value) {
static const std::vector<std::string> choices = models->getNames(); static const std::vector<std::string> choices = models->getNames();
if (find(choices.begin(), choices.end(), value) != choices.end()) { if (find(choices.begin(), choices.end(), value) != choices.end()) {
return value; return value;
} }
throw std::runtime_error("Model must be one of " + models->tostring()); throw std::runtime_error("Model must be one of " + models->toString());
} }
); );
} }
@@ -259,7 +259,7 @@ int main(int argc, char** argv)
} }
} }
if (!found) { if (!found) {
throw std::runtime_error("You must specify one of the following commands: dump, report, compute, export\n"); throw std::runtime_error("You must specify one of the following commands: dump, report, compute\n");
} }
} }
catch (const exception& err) { catch (const exception& err) {

View File

@@ -1,5 +1,6 @@
#include <iostream> #include <iostream>
#include <locale> #include <locale>
#include <map>
#include <argparse/argparse.hpp> #include <argparse/argparse.hpp>
#include <nlohmann/json.hpp> #include <nlohmann/json.hpp>
#include "common/Paths.h" #include "common/Paths.h"
@@ -28,15 +29,9 @@ std::string outputBalance(const std::string& balance)
return temp; return temp;
} }
int main(int argc, char** argv) void list_datasets(argparse::ArgumentParser& program)
{ {
auto datasets = platform::Datasets(false, platform::Paths::datasets()); auto datasets = platform::Datasets(false, platform::Paths::datasets());
argparse::ArgumentParser program("b_list", { platform_project_version.begin(), platform_project_version.end() });
program.add_argument("--excel")
.help("Output in Excel format")
.default_value(false)
.implicit_value(true);
program.parse_args(argc, argv);
auto excel = program.get<bool>("--excel"); auto excel = program.get<bool>("--excel");
locale mylocale(std::cout.getloc(), new separated); locale mylocale(std::cout.getloc(), new separated);
locale::global(mylocale); locale::global(mylocale);
@@ -70,11 +65,72 @@ int main(int argc, char** argv)
data[dataset]["classes"] = datasets.getNClasses(dataset); data[dataset]["classes"] = datasets.getNClasses(dataset);
data[dataset]["balance"] = oss.str(); data[dataset]["balance"] = oss.str();
} }
std::cout << Colors::RESET() << std::endl;
if (excel) { if (excel) {
auto report = platform::DatasetsExcel(); auto report = platform::DatasetsExcel();
report.report(data); report.report(data);
std::cout << "Output saved in " << report.getFileName() << std::endl; std::cout << std::endl << Colors::GREEN() << "Output saved in " << report.getFileName() << std::endl;
} }
return 0;
} }
void list_results(argparse::ArgumentParser& program)
{
std::cout << "Results" << std::endl;
}
int main(int argc, char** argv)
{
argparse::ArgumentParser program("b_list", { platform_project_version.begin(), platform_project_version.end() });
//
// datasets subparser
//
argparse::ArgumentParser datasets_command("datasets");
datasets_command.add_description("List datasets available in the platform.");
datasets_command.add_argument("--excel")
.help("Output in Excel format")
.default_value(false)
.implicit_value(true);
//
// results subparser
//
argparse::ArgumentParser results_command("results");
results_command.add_description("List the results of a given dataset.");
auto datasets = platform::Datasets(false, platform::Paths::datasets());
results_command.add_argument("-d", "--dataset")
.help("Dataset to use " + datasets.toString())
.required()
.action([](const std::string& value) {
auto datasets = platform::Datasets(false, platform::Paths::datasets());
static const std::vector<std::string> choices = datasets.getNames();
if (find(choices.begin(), choices.end(), value) != choices.end()) {
return value;
}
throw std::runtime_error("Dataset must be one of " + datasets.toString());
}
);
// Add subparsers
program.add_subparser(datasets_command);
program.add_subparser(results_command);
// Parse command line and execute
try {
program.parse_args(argc, argv);
bool found = false;
map<std::string, void(*)(argparse::ArgumentParser&)> commands = { {"datasets", &list_datasets}, {"results", &list_results} };
for (const auto& command : commands) {
if (program.is_subcommand_used(command.first)) {
std::invoke(command.second, program.at<argparse::ArgumentParser>(command.first));
found = true;
break;
}
}
if (!found) {
throw std::runtime_error("You must specify one of the following commands: datasets, results\n");
}
}
catch (const exception& err) {
cerr << err.what() << std::endl;
cerr << program;
exit(1);
}
std::cout << Colors::RESET() << std::endl;
return 0;
}

View File

@@ -36,13 +36,15 @@ namespace platform {
[](const pair<std::string, function<bayesnet::BaseClassifier* (void)>>& pair) { return pair.first; }); [](const pair<std::string, function<bayesnet::BaseClassifier* (void)>>& pair) { return pair.first; });
return names; return names;
} }
std::string Models::tostring() std::string Models::toString()
{ {
std::string result = ""; std::string result = "";
std::string sep = "";
for (const auto& pair : functionRegistry) { for (const auto& pair : functionRegistry) {
result += pair.first + ", "; result += sep + pair.first;
sep = ", ";
} }
return "{" + result.substr(0, result.size() - 2) + "}"; return "{" + result + "}";
} }
Registrar::Registrar(const std::string& name, function<bayesnet::BaseClassifier* (void)> classFactoryFunction) Registrar::Registrar(const std::string& name, function<bayesnet::BaseClassifier* (void)> classFactoryFunction)
{ {

View File

@@ -31,7 +31,7 @@ namespace platform {
void registerFactoryFunction(const std::string& name, void registerFactoryFunction(const std::string& name,
function<bayesnet::BaseClassifier* (void)> classFactoryFunction); function<bayesnet::BaseClassifier* (void)> classFactoryFunction);
std::vector<string> getNames(); std::vector<string> getNames();
std::string tostring(); std::string toString();
}; };
class Registrar { class Registrar {

View File

@@ -15,18 +15,29 @@ using json = nlohmann::json;
void manageArguments(argparse::ArgumentParser& program) void manageArguments(argparse::ArgumentParser& program)
{ {
auto env = platform::DotEnv(); auto env = platform::DotEnv();
program.add_argument("-d", "--dataset").default_value("").help("Dataset file name"); auto datasets = platform::Datasets(false, platform::Paths::datasets());
program.add_argument("-d", "--dataset")
.help("Dataset file name: " + datasets.toString())
.action([](const std::string& value) {
auto datasets = platform::Datasets(false, platform::Paths::datasets());
static const std::vector<std::string> choices_datasets(datasets.getNames());
if (find(choices_datasets.begin(), choices_datasets.end(), value) != choices_datasets.end()) {
return value;
}
throw std::runtime_error("Dataset must be one of: " + datasets.toString());
}
);
program.add_argument("--hyperparameters").default_value("{}").help("Hyperparameters passed to the model in Experiment"); program.add_argument("--hyperparameters").default_value("{}").help("Hyperparameters passed to the model in Experiment");
program.add_argument("--hyper-file").default_value("").help("Hyperparameters file name." \ program.add_argument("--hyper-file").default_value("").help("Hyperparameters file name." \
"Mutually exclusive with hyperparameters. This file should contain hyperparameters for each dataset in json format."); "Mutually exclusive with hyperparameters. This file should contain hyperparameters for each dataset in json format.");
program.add_argument("-m", "--model") program.add_argument("-m", "--model")
.help("Model to use " + platform::Models::instance()->tostring()) .help("Model to use: " + platform::Models::instance()->toString())
.action([](const std::string& value) { .action([](const std::string& value) {
static const std::vector<std::string> choices = platform::Models::instance()->getNames(); static const std::vector<std::string> choices = platform::Models::instance()->getNames();
if (find(choices.begin(), choices.end(), value) != choices.end()) { if (find(choices.begin(), choices.end(), value) != choices.end()) {
return value; return value;
} }
throw std::runtime_error("Model must be one of " + platform::Models::instance()->tostring()); throw std::runtime_error("Model must be one of " + platform::Models::instance()->toString());
} }
); );
program.add_argument("--title").default_value("").help("Experiment title"); program.add_argument("--title").default_value("").help("Experiment title");