Add command results to b_list

Rename tostring -> toString in models
Add datasets names to b_main command help - validation
This commit is contained in:
2024-03-10 12:16:02 +01:00
parent 4c847fc3f6
commit b3b3d9f1b9
9 changed files with 104 additions and 24 deletions

View File

@@ -85,13 +85,13 @@ int main(int argc, char** argv)
.default_value(std::string{ PATH }
);
program.add_argument("-m", "--model")
.help("Model to use " + platform::Models::instance()->tostring())
.help("Model to use " + platform::Models::instance()->toString())
.action([](const std::string& value) {
static const std::vector<std::string> choices = platform::Models::instance()->getNames();
if (find(choices.begin(), choices.end(), value) != choices.end()) {
return value;
}
throw runtime_error("Model must be one of " + platform::Models::instance()->tostring());
throw runtime_error("Model must be one of " + platform::Models::instance()->toString());
}
);
program.add_argument("--discretize").help("Discretize input dataset").default_value(false).implicit_value(true);

View File

@@ -20,8 +20,8 @@ set(best_sources b_best.cc BestResults.cc Statistics.cc BestResultsExcel.cc)
list(TRANSFORM best_sources PREPEND best/)
add_executable(
b_best ${best_sources} main/Result.cc
reports/ReportExcel.cc reports/ReportBase.cc reports/ExcelFile.cc common/Datasets.cc common/Dataset.cc)
target_link_libraries(b_best Boost::boost "${TORCH_LIBRARIES}" "${XLSXWRITER_LIB}" ArffFiles mdlp)
reports/ReportExcel.cc reports/ReportBase.cc reports/ExcelFile.cc common/Datasets.cc common/Dataset.cc main/Models.cc)
target_link_libraries(b_best Boost::boost "${PyClassifiers}" "${BayesNet}" ArffFiles mdlp ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::python Boost::numpy "${XLSXWRITER_LIB}")
# b_grid
set(grid_sources b_grid.cc GridSearch.cc GridData.cc)

View File

@@ -126,4 +126,14 @@ namespace platform {
{
return datasets.find(name) != datasets.end();
}
std::string Datasets::toString() const
{
std::string result;
std::string sep = "";
for (const auto& d : datasets) {
result += sep + d.first;
sep = ", ";
}
return "{" + result + "}";
}
}

View File

@@ -24,6 +24,7 @@ namespace platform {
std::pair<torch::Tensor&, torch::Tensor&> getTensors(const std::string& name);
bool isDataset(const std::string& name) const;
void loadDataset(const std::string& name) const;
std::string toString() const;
};
};

View File

@@ -20,14 +20,14 @@ void assignModel(argparse::ArgumentParser& parser)
{
auto models = platform::Models::instance();
parser.add_argument("-m", "--model")
.help("Model to use " + models->tostring())
.help("Model to use " + models->toString())
.required()
.action([models](const std::string& value) {
static const std::vector<std::string> choices = models->getNames();
if (find(choices.begin(), choices.end(), value) != choices.end()) {
return value;
}
throw std::runtime_error("Model must be one of " + models->tostring());
throw std::runtime_error("Model must be one of " + models->toString());
}
);
}
@@ -259,7 +259,7 @@ int main(int argc, char** argv)
}
}
if (!found) {
throw std::runtime_error("You must specify one of the following commands: dump, report, compute, export\n");
throw std::runtime_error("You must specify one of the following commands: dump, report, compute\n");
}
}
catch (const exception& err) {

View File

@@ -1,5 +1,6 @@
#include <iostream>
#include <locale>
#include <map>
#include <argparse/argparse.hpp>
#include <nlohmann/json.hpp>
#include "common/Paths.h"
@@ -28,15 +29,9 @@ std::string outputBalance(const std::string& balance)
return temp;
}
int main(int argc, char** argv)
void list_datasets(argparse::ArgumentParser& program)
{
auto datasets = platform::Datasets(false, platform::Paths::datasets());
argparse::ArgumentParser program("b_list", { platform_project_version.begin(), platform_project_version.end() });
program.add_argument("--excel")
.help("Output in Excel format")
.default_value(false)
.implicit_value(true);
program.parse_args(argc, argv);
auto excel = program.get<bool>("--excel");
locale mylocale(std::cout.getloc(), new separated);
locale::global(mylocale);
@@ -70,11 +65,72 @@ int main(int argc, char** argv)
data[dataset]["classes"] = datasets.getNClasses(dataset);
data[dataset]["balance"] = oss.str();
}
std::cout << Colors::RESET() << std::endl;
if (excel) {
auto report = platform::DatasetsExcel();
report.report(data);
std::cout << "Output saved in " << report.getFileName() << std::endl;
std::cout << std::endl << Colors::GREEN() << "Output saved in " << report.getFileName() << std::endl;
}
return 0;
}
void list_results(argparse::ArgumentParser& program)
{
std::cout << "Results" << std::endl;
}
int main(int argc, char** argv)
{
argparse::ArgumentParser program("b_list", { platform_project_version.begin(), platform_project_version.end() });
//
// datasets subparser
//
argparse::ArgumentParser datasets_command("datasets");
datasets_command.add_description("List datasets available in the platform.");
datasets_command.add_argument("--excel")
.help("Output in Excel format")
.default_value(false)
.implicit_value(true);
//
// results subparser
//
argparse::ArgumentParser results_command("results");
results_command.add_description("List the results of a given dataset.");
auto datasets = platform::Datasets(false, platform::Paths::datasets());
results_command.add_argument("-d", "--dataset")
.help("Dataset to use " + datasets.toString())
.required()
.action([](const std::string& value) {
auto datasets = platform::Datasets(false, platform::Paths::datasets());
static const std::vector<std::string> choices = datasets.getNames();
if (find(choices.begin(), choices.end(), value) != choices.end()) {
return value;
}
throw std::runtime_error("Dataset must be one of " + datasets.toString());
}
);
// Add subparsers
program.add_subparser(datasets_command);
program.add_subparser(results_command);
// Parse command line and execute
try {
program.parse_args(argc, argv);
bool found = false;
map<std::string, void(*)(argparse::ArgumentParser&)> commands = { {"datasets", &list_datasets}, {"results", &list_results} };
for (const auto& command : commands) {
if (program.is_subcommand_used(command.first)) {
std::invoke(command.second, program.at<argparse::ArgumentParser>(command.first));
found = true;
break;
}
}
if (!found) {
throw std::runtime_error("You must specify one of the following commands: datasets, results\n");
}
}
catch (const exception& err) {
cerr << err.what() << std::endl;
cerr << program;
exit(1);
}
std::cout << Colors::RESET() << std::endl;
return 0;
}

View File

@@ -36,13 +36,15 @@ namespace platform {
[](const pair<std::string, function<bayesnet::BaseClassifier* (void)>>& pair) { return pair.first; });
return names;
}
std::string Models::tostring()
std::string Models::toString()
{
std::string result = "";
std::string sep = "";
for (const auto& pair : functionRegistry) {
result += pair.first + ", ";
result += sep + pair.first;
sep = ", ";
}
return "{" + result.substr(0, result.size() - 2) + "}";
return "{" + result + "}";
}
Registrar::Registrar(const std::string& name, function<bayesnet::BaseClassifier* (void)> classFactoryFunction)
{

View File

@@ -31,7 +31,7 @@ namespace platform {
void registerFactoryFunction(const std::string& name,
function<bayesnet::BaseClassifier* (void)> classFactoryFunction);
std::vector<string> getNames();
std::string tostring();
std::string toString();
};
class Registrar {

View File

@@ -15,18 +15,29 @@ using json = nlohmann::json;
void manageArguments(argparse::ArgumentParser& program)
{
auto env = platform::DotEnv();
program.add_argument("-d", "--dataset").default_value("").help("Dataset file name");
auto datasets = platform::Datasets(false, platform::Paths::datasets());
program.add_argument("-d", "--dataset")
.help("Dataset file name: " + datasets.toString())
.action([](const std::string& value) {
auto datasets = platform::Datasets(false, platform::Paths::datasets());
static const std::vector<std::string> choices_datasets(datasets.getNames());
if (find(choices_datasets.begin(), choices_datasets.end(), value) != choices_datasets.end()) {
return value;
}
throw std::runtime_error("Dataset must be one of: " + datasets.toString());
}
);
program.add_argument("--hyperparameters").default_value("{}").help("Hyperparameters passed to the model in Experiment");
program.add_argument("--hyper-file").default_value("").help("Hyperparameters file name." \
"Mutually exclusive with hyperparameters. This file should contain hyperparameters for each dataset in json format.");
program.add_argument("-m", "--model")
.help("Model to use " + platform::Models::instance()->tostring())
.help("Model to use: " + platform::Models::instance()->toString())
.action([](const std::string& value) {
static const std::vector<std::string> choices = platform::Models::instance()->getNames();
if (find(choices.begin(), choices.end(), value) != choices.end()) {
return value;
}
throw std::runtime_error("Model must be one of " + platform::Models::instance()->tostring());
throw std::runtime_error("Model must be one of " + platform::Models::instance()->toString());
}
);
program.add_argument("--title").default_value("").help("Experiment title");