Fix errors in grid Experiment

This commit is contained in:
2025-01-19 13:51:51 +01:00
parent 2a9652b450
commit 07e3cc9599
6 changed files with 9 additions and 3 deletions

View File

@@ -284,6 +284,7 @@ int main(int argc, char** argv)
argparse::ArgumentParser experiment_command("experiment");
experiment_command.add_description("Experiment like b_main using mpi.");
auto arguments = platform::ArgumentsExperiment(experiment_command, platform::experiment_t::GRID);
arguments.add_arguments();
program.add_subparser(dump_command);
program.add_subparser(report_command);
program.add_subparser(search_command);

View File

@@ -11,6 +11,7 @@ int main(int argc, char** argv)
{
argparse::ArgumentParser program("b_main", { platform_project_version.begin(), platform_project_version.end() });
auto arguments = platform::ArgumentsExperiment(program, platform::experiment_t::NORMAL);
arguments.add_arguments();
arguments.parse_args(argc, argv);
/*
* Begin Processing

View File

@@ -12,6 +12,8 @@ namespace platform {
GridExperiment::GridExperiment(ArgumentsExperiment& program, struct ConfigGrid& config) : arguments(program), GridBase(config)
{
experiment = arguments.initializedExperiment();
filesToTest = arguments.getFilesToTest();
saveResults = arguments.haveToSaveResults();
this->config.model = experiment.getModel();
this->config.score = experiment.getScore();
this->config.discretize = experiment.isDiscretized();

View File

@@ -29,7 +29,7 @@ namespace platform {
ArgumentsExperiment& arguments;
Experiment experiment;
json computed_results;
bool saveResults;
bool saveResults = false;
std::vector<std::string> filesToTest;
void save(json& results);
json initializeResults();

View File

@@ -6,6 +6,9 @@
#include "ArgumentsExperiment.h"
namespace platform {
ArgumentsExperiment::ArgumentsExperiment(argparse::ArgumentParser& program, experiment_t type) : arguments{ program }, type{ type }
{
}
void ArgumentsExperiment::add_arguments()
{
auto env = platform::DotEnv();
auto datasets = platform::Datasets(false, platform::Paths::datasets());
@@ -106,7 +109,6 @@ namespace platform {
smooth_strat = arguments.get<std::string>("smooth-strat");
stratified = arguments.get<bool>("stratified");
quiet = arguments.get<bool>("quiet");
n_folds = arguments.get<int>("folds");
score = arguments.get<std::string>("score");
seeds = arguments.get<std::vector<int>>("seeds");
@@ -196,7 +198,6 @@ namespace platform {
}
}
}
if (hyperparameters_file != "") {
test_hyperparams = platform::HyperParameters(datasets.getNames(), hyperparameters_file, hyper_best);
} else {

View File

@@ -15,6 +15,7 @@ namespace platform {
ArgumentsExperiment(argparse::ArgumentParser& program, experiment_t type);
~ArgumentsExperiment() = default;
std::vector<std::string> getFilesToTest() const { return filesToTest; }
void add_arguments();
void parse_args(int argc, char** argv);
void parse();
Experiment& initializedExperiment();