Continue with grid experiment

This commit is contained in:
2025-01-17 10:39:56 +01:00
parent 9a9a9fb17a
commit c1d5dd74e3
12 changed files with 238 additions and 85 deletions

View File

@@ -36,22 +36,22 @@ void add_experiment_args(argparse::ArgumentParser& program)
{
auto env = platform::DotEnv();
auto datasets = platform::Datasets(false, platform::Paths::datasets());
// auto& group = program.add_mutually_exclusive_group(true);
// group.add_argument("-d", "--dataset")
// .help("Dataset file name: " + datasets.toString())
// .default_value("all")
// .action([](const std::string& value) {
// auto datasets = platform::Datasets(false, platform::Paths::datasets());
// static std::vector<std::string> choices_datasets(datasets.getNames());
// choices_datasets.push_back("all");
// if (find(choices_datasets.begin(), choices_datasets.end(), value) != choices_datasets.end()) {
// return value;
// }
// throw std::runtime_error("Dataset must be one of: " + datasets.toString());
// }
// );
// group.add_argument("--datasets").nargs(1, 50).help("Datasets file names 1..50 separated by spaces").default_value(std::vector<std::string>());
// group.add_argument("--datasets-file").default_value("").help("Datasets file name. Mutually exclusive with dataset. This file should contain a list of datasets to test.");
auto& group = program.add_mutually_exclusive_group(true);
group.add_argument("-d", "--dataset")
.help("Dataset file name: " + datasets.toString())
.default_value("all")
.action([](const std::string& value) {
auto datasets = platform::Datasets(false, platform::Paths::datasets());
static std::vector<std::string> choices_datasets(datasets.getNames());
choices_datasets.push_back("all");
if (find(choices_datasets.begin(), choices_datasets.end(), value) != choices_datasets.end()) {
return value;
}
throw std::runtime_error("Dataset must be one of: " + datasets.toString());
}
);
group.add_argument("--datasets").nargs(1, 50).help("Datasets file names 1..50 separated by spaces").default_value(std::vector<std::string>());
group.add_argument("--datasets-file").default_value("").help("Datasets file name. Mutually exclusive with dataset. This file should contain a list of datasets to test.");
program.add_argument("--hyperparameters").default_value("{}").help("Hyperparameters passed to the model in Experiment");
program.add_argument("--hyper-file").default_value("").help("Hyperparameters file name." \
"Mutually exclusive with hyperparameters. This file should contain hyperparameters for each dataset in json format.");
@@ -83,11 +83,6 @@ void add_experiment_args(argparse::ArgumentParser& program)
for (auto choice : valid_choices) {
score_arg.choices(choice);
}
program.add_argument("--generate-fold-files").help("generate fold information in datasets_experiment folder").default_value(false).implicit_value(true);
program.add_argument("--graph").help("generate graphviz dot files with the model").default_value(false).implicit_value(true);
program.add_argument("--no-train-score").help("Don't compute train score").default_value(false).implicit_value(true);
program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true);
program.add_argument("--save").help("Save result (always save if no dataset is supplied)").default_value(false).implicit_value(true);
program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true);
program.add_argument("-f", "--folds").help("Number of folds").default_value(stoi(env.get("n_folds"))).scan<'i', int>().action([](const std::string& value) {
try {
@@ -307,19 +302,10 @@ void search(argparse::ArgumentParser& program)
void experiment(argparse::ArgumentParser& program)
{
struct platform::ConfigGrid config;
config.model = program.get<std::string>("model");
config.score = program.get<std::string>("score");
config.discretize = program.get<bool>("discretize");
config.stratified = program.get<bool>("stratified");
config.smooth_strategy = program.get<std::string>("smooth-strat");
config.n_folds = program.get<int>("folds");
config.quiet = program.get<bool>("quiet");
config.seeds = program.get<std::vector<int>>("seeds");
auto env = platform::DotEnv();
config.platform = env.get("platform");
platform::Paths::createPath(platform::Paths::grid());
auto grid_experiment = platform::GridExperiment(config);
auto grid_experiment = platform::GridExperiment(program, config);
platform::Timer timer;
timer.start();
struct platform::ConfigMPI mpi_config;
@@ -333,6 +319,7 @@ void experiment(argparse::ArgumentParser& program)
grid_experiment.go(mpi_config);
if (mpi_config.rank == mpi_config.manager) {
auto results = grid_experiment.getResults();
//build_experiment_result(results);
std::cout << "****** RESULTS ********" << std::endl;
std::cout << results.dump(4) << std::endl;
// list_results(results, config.model);