@@ -1,6 +1,7 @@
#include <iostream>
#include <argparse/argparse.hpp>
#include <map>
#include <tuple>
#include <nlohmann/json.hpp>
#include <mpi.h>
#include "DotEnv.h"
@@ -15,23 +16,24 @@
using json = nlohmann::json;
const int MAXL = 133;
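// MAXL is the fixed width of the banner lines printed by the listing/report helpers below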

void manageArguments(argparse::ArgumentParser& program)
void assignModel(argparse::ArgumentParser& parser)
{
    auto env = platform::DotEnv();
    auto& group = program.add_mutually_exclusive_group(true);
    program.add_argument("-m", "--model")
        .help("Model to use " + platform::Models::instance()->tostring())
        .action([](const std::string& value) {
        static const std::vector<std::string> choices = platform::Models::instance()->getNames();
    auto models = platform::Models::instance();
    parser.add_argument("-m", "--model")
        .help("Model to use " + models->tostring())
        .required()
        .action([models](const std::string& value) {
        static const std::vector<std::string> choices = models->getNames();
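        // accept only model names registered in the Models factory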
        if (find(choices.begin(), choices.end(), value) != choices.end()) {
            return value;
        }
        throw std::runtime_error("Model must be one of " + platform::Models::instance()->tostring());
        throw std::runtime_error("Model must be one of " + models->tostring());
        }
    );
    group.add_argument("--dump").help("Show the grid combinations").default_value(false).implicit_value(true);
    group.add_argument("--report").help("Report the computed hyperparameters").default_value(false).implicit_value(true);
    group.add_argument("--compute").help("Perform computation of the grid output hyperparameters").default_value(false).implicit_value(true);
}
void add_compute_args(argparse::ArgumentParser& program)
{
    auto env = platform::DotEnv();
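    // Default values for the boolean options below are taken from the project's .env file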
    program.add_argument("--discretize").help("Discretize input datasets").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
    program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
    program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
@@ -70,11 +72,19 @@ void manageArguments(argparse::ArgumentParser& program)
    auto seed_values = env.getSeeds();
    program.add_argument("-s", "--seeds").nargs(1, 10).help("Random seeds. Set to -1 to have pseudo random").scan<'i', int>().default_value(seed_values);
}

std::string headerLine(const std::string& text, int utf = 0)
{
    int n = MAXL - text.length() - 3;
    n = n < 0 ? 0 : n;
    return "* " + text + std::string(n + utf, ' ') + "*\n";
}
void list_dump(std::string& model)
{
    auto data = platform::GridData(platform::Paths::grid_input(model));
    std::cout << Colors::MAGENTA() << "Listing configuration input file (Grid)" << std::endl << std::endl;
    std::cout << Colors::MAGENTA() << std::string(MAXL, '*') << std::endl;
    std::cout << headerLine("Listing configuration input file (Grid)");
    std::cout << headerLine("Model: " + model);
    std::cout << Colors::MAGENTA() << std::string(MAXL, '*') << std::endl;
    int index = 0;
    int max_hyper = 15;
    int max_dataset = 7;
@@ -96,17 +106,11 @@ void list_dump(std::string& model)
        std::cout << color;
        auto num_combinations = data.getNumCombinations(item.first);
        std::cout << setw(3) << fixed << right << ++index << left << " " << setw(max_dataset) << item.first
            << " " << setw(5) << right << num_combinations << " " << setw(max_hyper) << item.second.dump() << std::endl;
            << " " << setw(5) << right << num_combinations << " " << setw(max_hyper) << left << item.second.dump() << std::endl;
        odd = !odd;
    }
    std::cout << Colors::RESET() << std::endl;
}
std::string headerLine(const std::string& text, int utf = 0)
{
    int n = MAXL - text.length() - 3;
    n = n < 0 ? 0 : n;
    return "* " + text + std::string(n + utf, ' ') + "*\n";
}
void list_results(json& results, std::string& model)
{
    std::cout << Colors::MAGENTA() << std::string(MAXL, '*') << std::endl;
@@ -155,77 +159,136 @@ void list_results(json& results, std::string& model)
/*
 * Main
 */
int main(int argc, char** argv)
void dump(argparse::ArgumentParser& program)
{
    argparse::ArgumentParser program("b_grid", { project_version.begin(), project_version.end() });
    manageArguments(program);
    auto model = program.get<std::string>("model");
    list_dump(model);
}
void report(argparse::ArgumentParser& program)
{
    // List results
    struct platform::ConfigGrid config;
    bool dump, compute;
    try {
        program.parse_args(argc, argv);
        config.model = program.get<std::string>("model");
        config.score = program.get<std::string>("score");
        config.discretize = program.get<bool>("discretize");
        config.stratified = program.get<bool>("stratified");
        config.n_folds = program.get<int>("folds");
        config.quiet = program.get<bool>("quiet");
        config.only = program.get<bool>("only");
        config.seeds = program.get<std::vector<int>>("seeds");
        config.nested = program.get<int>("nested");
        config.continue_from = program.get<std::string>("continue");
        if (config.continue_from == platform::GridSearch::NO_CONTINUE() && config.only) {
            throw std::runtime_error("Cannot use --only without --continue");
        }
        dump = program.get<bool>("dump");
        compute = program.get<bool>("compute");
        if (dump && (config.continue_from != platform::GridSearch::NO_CONTINUE() || config.only)) {
            throw std::runtime_error("Cannot use --dump with --continue or --only");
        }
        auto excluded = program.get<std::string>("exclude");
        config.excluded = json::parse(excluded);
        config.model = program.get<std::string>("model");
        auto grid_search = platform::GridSearch(config);
        auto results = grid_search.loadResults();
        if (results.empty()) {
            std::cout << "** No results found" << std::endl;
        } else {
            list_results(results, config.model);
        }
    catch (const exception& err) {
        cerr << err.what() << std::endl;
        cerr << program;
        exit(1);
    }
void exportResults(argparse::ArgumentParser& program)
{
    // Generate a grid_<model_name>.json file with the results of the grid search
    // this file can be used by b_main to run the model with the best hyperparameters
    struct platform::ConfigGrid config;
    config.model = program.get<std::string>("model");
    auto grid_search = platform::GridSearch(config);
    auto results = grid_search.loadResults();
    auto output = json::array();
    if (results.empty()) {
        std::cout << "** No results found" << std::endl;
    } else {
        grid_search.exportResults(results);
        std::cout << "Exported results to " << platform::Paths::grid_export(config.model) << std::endl;
    }
    /*
     * Begin Processing
     */
}
void compute(argparse::ArgumentParser& program)
{
    struct platform::ConfigGrid config;
    config.model = program.get<std::string>("model");
    config.score = program.get<std::string>("score");
    config.discretize = program.get<bool>("discretize");
    config.stratified = program.get<bool>("stratified");
    config.n_folds = program.get<int>("folds");
    config.quiet = program.get<bool>("quiet");
    config.only = program.get<bool>("only");
    config.seeds = program.get<std::vector<int>>("seeds");
    config.nested = program.get<int>("nested");
    config.continue_from = program.get<std::string>("continue");
    if (config.continue_from == platform::GridSearch::NO_CONTINUE() && config.only) {
        throw std::runtime_error("Cannot use --only without --continue");
    }
    auto excluded = program.get<std::string>("exclude");
    config.excluded = json::parse(excluded);

    auto env = platform::DotEnv();
    config.platform = env.get("platform");
    platform::Paths::createPath(platform::Paths::grid());
    auto grid_search = platform::GridSearch(config);
    platform::Timer timer;
    timer.start();
    if (dump) {
        list_dump(config.model);
    } else {
        if (compute) {
            struct platform::ConfigMPI mpi_config;
            mpi_config.manager = 0; // which process is the manager
            MPI_Init(&argc, &argv);
            MPI_Comm_rank(MPI_COMM_WORLD, &mpi_config.rank);
            MPI_Comm_size(MPI_COMM_WORLD, &mpi_config.n_procs);
            if (mpi_config.n_procs < 2) {
                throw std::runtime_error("Cannot use --compute with less than 2 mpi processes, try mpirun -np 2 ...");
            }
            grid_search.go(mpi_config);
            if (mpi_config.rank == mpi_config.manager) {
                auto results = grid_search.loadResults();
                list_results(results, config.model);
                std::cout << "Process took " << timer.getDurationString() << std::endl;
            }
            MPI_Finalize();
        } else {
            // List results
            auto results = grid_search.loadResults();
            if (results.empty()) {
                std::cout << "** No results found" << std::endl;
            } else {
                list_results(results, config.model);
    struct platform::ConfigMPI mpi_config;
    mpi_config.manager = 0; // which process is the manager
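    // Every rank takes part in the grid search; only the manager rank prints the aggregated results afterwards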
    MPI_Init(nullptr, nullptr);
    MPI_Comm_rank(MPI_COMM_WORLD, &mpi_config.rank);
    MPI_Comm_size(MPI_COMM_WORLD, &mpi_config.n_procs);
    if (mpi_config.n_procs < 2) {
        throw std::runtime_error("Cannot use --compute with less than 2 mpi processes, try mpirun -np 2 ...");
    }
    grid_search.go(mpi_config);
    if (mpi_config.rank == mpi_config.manager) {
        auto results = grid_search.loadResults();
        list_results(results, config.model);
        std::cout << "Process took " << timer.getDurationString() << std::endl;
    }
    MPI_Finalize();
}
int main(int argc, char** argv)
{
    //
    // Manage arguments
    //
    argparse::ArgumentParser program("b_grid", { project_version.begin(), project_version.end() });
    // grid dump subparser
    argparse::ArgumentParser dump_command("dump");
    dump_command.add_description("Dump the combinations of hyperparameters of a model.");
    assignModel(dump_command);

    // grid report subparser
    argparse::ArgumentParser report_command("report");
    assignModel(report_command);
    report_command.add_description("Report the computed hyperparameters of a model.");

    // grid compute subparser
    argparse::ArgumentParser compute_command("compute");
    compute_command.add_description("Compute using mpi the hyperparameters of a model.");
    assignModel(compute_command);
    add_compute_args(compute_command);

    // grid export subparser
    argparse::ArgumentParser export_command("export");
    assignModel(export_command);
    export_command.add_description("Export the computed hyperparameters to a file readable by b_main.");

    program.add_subparser(dump_command);
    program.add_subparser(report_command);
    program.add_subparser(compute_command);
    program.add_subparser(export_command);
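    // Example invocations (sketch; the exact options depend on the arguments registered above):
    //   b_grid dump -m <model>
    //   b_grid report -m <model>
    //   mpirun -np 2 b_grid compute -m <model>
    //   b_grid export -m <model>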

    //
    // Process options
    //
    try {
        program.parse_args(argc, argv);
        bool found = false;
        map<std::string, void(*)(argparse::ArgumentParser&)> commands =
        { {"dump", &dump}, {"report", &report}, {"export", &exportResults}, {"compute", &compute} };
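        // Dispatch to the handler of whichever subcommand was given on the command line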
        for (const auto& command : commands) {
            if (program.is_subcommand_used(command.first)) {
                std::invoke(command.second, program.at<argparse::ArgumentParser>(command.first));
                found = true;
                break;
            }
        }
        if (!found) {
            throw std::runtime_error("You must specify one of the following commands: dump, report, compute, export\n");
        }
    }
    catch (const exception& err) {
        cerr << err.what() << std::endl;
        cerr << program;
        exit(1);
    }
    std::cout << "Done!" << std::endl;
    return 0;