Add folder parameter to best, grid and main

This commit is contained in:
2025-05-14 11:46:15 +02:00
parent 321e2a2f28
commit d6603dd638
9 changed files with 24 additions and 11 deletions

View File

@@ -44,6 +44,9 @@ int main(int argc, char** argv)
program.parse_args(argc, argv);
model = program.get<std::string>("model");
folder = program.get<std::string>("folder");
if (folder.back() != '/') {
folder += '/';
}
dataset = program.get<std::string>("dataset");
score = program.get<std::string>("score");
friedman = program.get<bool>("friedman");

View File

@@ -231,6 +231,7 @@ void experiment(argparse::ArgumentParser& program)
{
struct platform::ConfigGrid config;
auto arguments = platform::ArgumentsExperiment(program, platform::experiment_t::GRID);
auto path_results = arguments.getPathResults();
arguments.parse();
auto grid_experiment = platform::GridExperiment(arguments, config);
platform::Timer timer;
@@ -250,7 +251,7 @@ void experiment(argparse::ArgumentParser& program)
auto duration = timer.getDuration();
experiment.setDuration(duration);
if (grid_experiment.haveToSaveResults()) {
experiment.saveResult();
experiment.saveResult(path_results);
}
experiment.report();
std::cout << "Process took " << duration << std::endl;

View File

@@ -18,6 +18,7 @@ int main(int argc, char** argv)
*/
// Initialize the experiment class with the command line arguments
auto experiment = arguments.initializedExperiment();
auto path_results = arguments.getPathResults();
platform::Timer timer;
timer.start();
experiment.go();
@@ -27,7 +28,7 @@ int main(int argc, char** argv)
experiment.report();
}
if (arguments.haveToSaveResults()) {
experiment.saveResult();
experiment.saveResult(path_results);
}
if (arguments.doGraph()) {
experiment.saveGraph();

View File

@@ -13,6 +13,7 @@ namespace platform {
auto env = platform::DotEnv();
auto datasets = platform::Datasets(false, platform::Paths::datasets());
auto& group = arguments.add_mutually_exclusive_group(true);
group.add_argument("-d", "--dataset")
.help("Dataset file name: " + datasets.toString())
.default_value("all")
@@ -43,6 +44,7 @@ namespace platform {
}
);
arguments.add_argument("--title").default_value("").help("Experiment title");
arguments.add_argument("--folder").help("Results folder to use").default_value(platform::Paths::results());
arguments.add_argument("--discretize").help("Discretize input dataset").default_value((bool)stoi(env.get("discretize"))).implicit_value(true);
auto valid_choices = env.valid_tokens("discretize_algo");
auto& disc_arg = arguments.add_argument("--discretize-algo").help("Algorithm to use in discretization. Valid values: " + env.valid_values("discretize_algo")).default_value(env.get("discretize_algo"));
@@ -103,6 +105,10 @@ namespace platform {
file_name = arguments.get<std::string>("dataset");
file_names = arguments.get<std::vector<std::string>>("datasets");
datasets_file = arguments.get<std::string>("datasets-file");
path_results = arguments.get<std::string>("folder");
if (path_results.back() != '/') {
path_results += '/';
}
model_name = arguments.get<std::string>("model");
discretize_dataset = arguments.get<bool>("discretize");
discretize_algo = arguments.get<std::string>("discretize-algo");
@@ -119,7 +125,7 @@ namespace platform {
hyper_best = arguments.get<bool>("hyper-best");
if (hyper_best) {
// Build the best results file_name
hyperparameters_file = platform::Paths::results() + platform::Paths::bestResultsFile(score, model_name);
hyperparameters_file = path_results + platform::Paths::bestResultsFile(score, model_name);
// ignore this parameter
hyperparameters = "{}";
} else {

View File

@@ -22,11 +22,13 @@ namespace platform {
bool isQuiet() const { return quiet; }
bool haveToSaveResults() const { return saveResults; }
bool doGraph() const { return graph; }
std::string getPathResults() const { return path_results; }
private:
Experiment experiment;
experiment_t type;
argparse::ArgumentParser& arguments;
std::string file_name, model_name, title, hyperparameters_file, datasets_file, discretize_algo, smooth_strat, score;
std::string file_name, model_name, title, hyperparameters_file, datasets_file, discretize_algo, smooth_strat;
std::string score, path_results;
json hyperparameters_json;
bool discretize_dataset, stratified, saveResults, quiet, no_train_score, generate_fold_files, graph, hyper_best;
std::vector<int> seeds;

View File

@@ -7,12 +7,12 @@
namespace platform {
using json = nlohmann::ordered_json;
void Experiment::saveResult()
void Experiment::saveResult(const std::string& path)
{
result.setSchemaVersion("1.0");
result.check();
result.save();
std::cout << "Result saved in " << Paths::results() << result.getFilename() << std::endl;
result.save(path);
std::cout << "Result saved in " << path << result.getFilename() << std::endl;
}
void Experiment::report()
{

View File

@@ -45,7 +45,7 @@ namespace platform {
std::vector<int> getRandomSeeds() const { return randomSeeds; }
void cross_validation(const std::string& fileName);
void go();
void saveResult();
void saveResult(const std::string& path);
void show();
void saveGraph();
void report();

View File

@@ -69,9 +69,9 @@ namespace platform {
platform::JsonValidator validator(platform::SchemaV1_0::schema);
return validator.validate(data);
}
void Result::save()
void Result::save(const std::string& path)
{
std::ofstream file(Paths::results() + getFilename());
std::ofstream file(path + getFilename());
file << data;
file.close();
}

View File

@@ -15,7 +15,7 @@ namespace platform {
public:
Result();
Result& load(const std::string& path, const std::string& filename);
void save();
void save(const std::string& path);
std::vector<std::string> check();
// Getters
json getJson();