Add folder parameter to best, grid and main

This commit is contained in:
2025-05-14 11:46:15 +02:00
parent 321e2a2f28
commit d6603dd638
9 changed files with 24 additions and 11 deletions

View File

@@ -44,6 +44,9 @@ int main(int argc, char** argv)
program.parse_args(argc, argv); program.parse_args(argc, argv);
model = program.get<std::string>("model"); model = program.get<std::string>("model");
folder = program.get<std::string>("folder"); folder = program.get<std::string>("folder");
if (folder.back() != '/') {
folder += '/';
}
dataset = program.get<std::string>("dataset"); dataset = program.get<std::string>("dataset");
score = program.get<std::string>("score"); score = program.get<std::string>("score");
friedman = program.get<bool>("friedman"); friedman = program.get<bool>("friedman");

View File

@@ -231,6 +231,7 @@ void experiment(argparse::ArgumentParser& program)
{ {
struct platform::ConfigGrid config; struct platform::ConfigGrid config;
auto arguments = platform::ArgumentsExperiment(program, platform::experiment_t::GRID); auto arguments = platform::ArgumentsExperiment(program, platform::experiment_t::GRID);
auto path_results = arguments.getPathResults();
arguments.parse(); arguments.parse();
auto grid_experiment = platform::GridExperiment(arguments, config); auto grid_experiment = platform::GridExperiment(arguments, config);
platform::Timer timer; platform::Timer timer;
@@ -250,7 +251,7 @@ void experiment(argparse::ArgumentParser& program)
auto duration = timer.getDuration(); auto duration = timer.getDuration();
experiment.setDuration(duration); experiment.setDuration(duration);
if (grid_experiment.haveToSaveResults()) { if (grid_experiment.haveToSaveResults()) {
experiment.saveResult(); experiment.saveResult(path_results);
} }
experiment.report(); experiment.report();
std::cout << "Process took " << duration << std::endl; std::cout << "Process took " << duration << std::endl;

View File

@@ -18,6 +18,7 @@ int main(int argc, char** argv)
*/ */
// Initialize the experiment class with the command line arguments // Initialize the experiment class with the command line arguments
auto experiment = arguments.initializedExperiment(); auto experiment = arguments.initializedExperiment();
auto path_results = arguments.getPathResults();
platform::Timer timer; platform::Timer timer;
timer.start(); timer.start();
experiment.go(); experiment.go();
@@ -27,7 +28,7 @@ int main(int argc, char** argv)
experiment.report(); experiment.report();
} }
if (arguments.haveToSaveResults()) { if (arguments.haveToSaveResults()) {
experiment.saveResult(); experiment.saveResult(path_results);
} }
if (arguments.doGraph()) { if (arguments.doGraph()) {
experiment.saveGraph(); experiment.saveGraph();

View File

@@ -13,6 +13,7 @@ namespace platform {
auto env = platform::DotEnv(); auto env = platform::DotEnv();
auto datasets = platform::Datasets(false, platform::Paths::datasets()); auto datasets = platform::Datasets(false, platform::Paths::datasets());
auto& group = arguments.add_mutually_exclusive_group(true); auto& group = arguments.add_mutually_exclusive_group(true);
group.add_argument("-d", "--dataset") group.add_argument("-d", "--dataset")
.help("Dataset file name: " + datasets.toString()) .help("Dataset file name: " + datasets.toString())
.default_value("all") .default_value("all")
@@ -43,6 +44,7 @@ namespace platform {
} }
); );
arguments.add_argument("--title").default_value("").help("Experiment title"); arguments.add_argument("--title").default_value("").help("Experiment title");
arguments.add_argument("--folder").help("Results folder to use").default_value(platform::Paths::results());
arguments.add_argument("--discretize").help("Discretize input dataset").default_value((bool)stoi(env.get("discretize"))).implicit_value(true); arguments.add_argument("--discretize").help("Discretize input dataset").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
auto valid_choices = env.valid_tokens("discretize_algo"); auto valid_choices = env.valid_tokens("discretize_algo");
auto& disc_arg = arguments.add_argument("--discretize-algo").help("Algorithm to use in discretization. Valid values: " + env.valid_values("discretize_algo")).default_value(env.get("discretize_algo")); auto& disc_arg = arguments.add_argument("--discretize-algo").help("Algorithm to use in discretization. Valid values: " + env.valid_values("discretize_algo")).default_value(env.get("discretize_algo"));
@@ -103,6 +105,10 @@ namespace platform {
file_name = arguments.get<std::string>("dataset"); file_name = arguments.get<std::string>("dataset");
file_names = arguments.get<std::vector<std::string>>("datasets"); file_names = arguments.get<std::vector<std::string>>("datasets");
datasets_file = arguments.get<std::string>("datasets-file"); datasets_file = arguments.get<std::string>("datasets-file");
path_results = arguments.get<std::string>("folder");
if (path_results.back() != '/') {
path_results += '/';
}
model_name = arguments.get<std::string>("model"); model_name = arguments.get<std::string>("model");
discretize_dataset = arguments.get<bool>("discretize"); discretize_dataset = arguments.get<bool>("discretize");
discretize_algo = arguments.get<std::string>("discretize-algo"); discretize_algo = arguments.get<std::string>("discretize-algo");
@@ -119,7 +125,7 @@ namespace platform {
hyper_best = arguments.get<bool>("hyper-best"); hyper_best = arguments.get<bool>("hyper-best");
if (hyper_best) { if (hyper_best) {
// Build the best results file_name // Build the best results file_name
hyperparameters_file = platform::Paths::results() + platform::Paths::bestResultsFile(score, model_name); hyperparameters_file = path_results + platform::Paths::bestResultsFile(score, model_name);
// ignore this parameter // ignore this parameter
hyperparameters = "{}"; hyperparameters = "{}";
} else { } else {

View File

@@ -22,11 +22,13 @@ namespace platform {
bool isQuiet() const { return quiet; } bool isQuiet() const { return quiet; }
bool haveToSaveResults() const { return saveResults; } bool haveToSaveResults() const { return saveResults; }
bool doGraph() const { return graph; } bool doGraph() const { return graph; }
std::string getPathResults() const { return path_results; }
private: private:
Experiment experiment; Experiment experiment;
experiment_t type; experiment_t type;
argparse::ArgumentParser& arguments; argparse::ArgumentParser& arguments;
std::string file_name, model_name, title, hyperparameters_file, datasets_file, discretize_algo, smooth_strat, score; std::string file_name, model_name, title, hyperparameters_file, datasets_file, discretize_algo, smooth_strat;
std::string score, path_results;
json hyperparameters_json; json hyperparameters_json;
bool discretize_dataset, stratified, saveResults, quiet, no_train_score, generate_fold_files, graph, hyper_best; bool discretize_dataset, stratified, saveResults, quiet, no_train_score, generate_fold_files, graph, hyper_best;
std::vector<int> seeds; std::vector<int> seeds;

View File

@@ -7,12 +7,12 @@
namespace platform { namespace platform {
using json = nlohmann::ordered_json; using json = nlohmann::ordered_json;
void Experiment::saveResult() void Experiment::saveResult(const std::string& path)
{ {
result.setSchemaVersion("1.0"); result.setSchemaVersion("1.0");
result.check(); result.check();
result.save(); result.save(path);
std::cout << "Result saved in " << Paths::results() << result.getFilename() << std::endl; std::cout << "Result saved in " << path << result.getFilename() << std::endl;
} }
void Experiment::report() void Experiment::report()
{ {

View File

@@ -45,7 +45,7 @@ namespace platform {
std::vector<int> getRandomSeeds() const { return randomSeeds; } std::vector<int> getRandomSeeds() const { return randomSeeds; }
void cross_validation(const std::string& fileName); void cross_validation(const std::string& fileName);
void go(); void go();
void saveResult(); void saveResult(const std::string& path);
void show(); void show();
void saveGraph(); void saveGraph();
void report(); void report();

View File

@@ -69,9 +69,9 @@ namespace platform {
platform::JsonValidator validator(platform::SchemaV1_0::schema); platform::JsonValidator validator(platform::SchemaV1_0::schema);
return validator.validate(data); return validator.validate(data);
} }
void Result::save() void Result::save(const std::string& path)
{ {
std::ofstream file(Paths::results() + getFilename()); std::ofstream file(path + getFilename());
file << data; file << data;
file.close(); file.close();
} }

View File

@@ -15,7 +15,7 @@ namespace platform {
public: public:
Result(); Result();
Result& load(const std::string& path, const std::string& filename); Result& load(const std::string& path, const std::string& filename);
void save(); void save(const std::string& path);
std::vector<std::string> check(); std::vector<std::string> check();
// Getters // Getters
json getJson(); json getJson();