Refactor arguments management for Experimentation
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
#include <cstddef>
|
||||
#include "common/DotEnv.h"
|
||||
#include "common/Paths.h"
|
||||
#include "common/DotEnv.h"
|
||||
#include "GridBase.h"
|
||||
|
||||
namespace platform {
|
||||
@@ -9,6 +10,8 @@ namespace platform {
|
||||
// Constructs the grid base from a caller-supplied configuration.
// The incoming `config` is assigned to this->config, and then this->config.platform is
// overwritten with whatever env.get("platform") returns from a freshly built platform::DotEnv.
// NOTE(review): the bare "|" / "||||" lines in this span are separators left by the diff
// rendering this text was taken from; they are not C++.
GridBase::GridBase(struct ConfigGrid& config)
|
||||
{
|
||||
// Take the caller's settings first; the platform field is replaced right below.
this->config = config;
|
||||
// presumably DotEnv exposes the project's environment settings -- confirm in common/DotEnv.h
auto env = platform::DotEnv();
|
||||
this->config.platform = env.get("platform");
|
||||
|
||||
}
|
||||
void GridBase::validate_config()
|
||||
|
@@ -8,120 +8,18 @@
|
||||
#include "GridExperiment.h"
|
||||
|
||||
namespace platform {
|
||||
// GridExperiment constructor: stores the arguments object and forwards `config` to GridBase.
// NOTE(review): this span comes from a rendered diff (hunk "@@ -8,120 +8,18 @@"), so the removed and the
// added version of the constructor are interleaved without +/- markers, and the bare "|" / "||||" lines
// are rendering separators, not C++. Going by the hunk line counts, the first signature line and the long
// option-parsing body appear to be the removed code, while the commented-out signature, the
// ArgumentsExperiment signature and the closing initializedExperiment() section appear to be the added
// code -- confirm against the repository before relying on this reading.
// NOTE(review): GridBase(config) is written after arguments(program) in the initializer list; C++
// initializes the base class first regardless of the written order.
GridExperiment::GridExperiment(argparse::ArgumentParser& program, struct ConfigGrid& config) : arguments(program), GridBase(config)
|
||||
// GridExperiment::GridExperiment(argparse::ArgumentParser& program, struct ConfigGrid& config) : arguments(program), GridBase(config)
|
||||
GridExperiment::GridExperiment(ArgumentsExperiment& program, struct ConfigGrid& config) : arguments(program), GridBase(config)
|
||||
{
|
||||
// Locals that receive the values read from the argument parser below.
std::string file_name, model_name, title, hyperparameters_file, datasets_file, discretize_algo, smooth_strat, score;
|
||||
json hyperparameters_json;
|
||||
bool discretize_dataset, stratified, hyper_best;
|
||||
std::vector<int> seeds;
|
||||
std::vector<std::string> file_names;
|
||||
int n_folds;
|
||||
// Read each option from the parser by its name.
file_name = program.get<std::string>("dataset");
|
||||
file_names = program.get<std::vector<std::string>>("datasets");
|
||||
datasets_file = program.get<std::string>("datasets-file");
|
||||
model_name = program.get<std::string>("model");
|
||||
discretize_dataset = program.get<bool>("discretize");
|
||||
saveResults = program.get<bool>("save");
|
||||
discretize_algo = program.get<std::string>("discretize-algo");
|
||||
smooth_strat = program.get<std::string>("smooth-strat");
|
||||
stratified = program.get<bool>("stratified");
|
||||
n_folds = program.get<int>("folds");
|
||||
score = program.get<std::string>("score");
|
||||
seeds = program.get<std::vector<int>>("seeds");
|
||||
auto hyperparameters = program.get<std::string>("hyperparameters");
|
||||
// The JSON is parsed here, before the "hyper-best" handling below resets the string.
hyperparameters_json = json::parse(hyperparameters);
|
||||
hyperparameters_file = program.get<std::string>("hyper-file");
|
||||
hyper_best = program.get<bool>("hyper-best");
|
||||
// With "hyper-best" the hyperparameters file is derived from the score and model names and the
// "hyperparameters" string is reset; otherwise "hyper-file" and a non-default "hyperparameters"
// value may not be combined.
if (hyper_best) {
|
||||
// Build the best results file_name
|
||||
hyperparameters_file = platform::Paths::results() + platform::Paths::bestResultsFile(score, model_name);
|
||||
// ignore this parameter
|
||||
hyperparameters = "{}";
|
||||
} else {
|
||||
if (hyperparameters_file != "" && hyperparameters != "{}") {
|
||||
throw runtime_error("hyperparameters and hyper_file are mutually exclusive");
|
||||
}
|
||||
}
|
||||
title = program.get<std::string>("title");
|
||||
// An explicit title is required when the dataset option is "all".
if (title == "" && file_name == "all") {
|
||||
throw runtime_error("title is mandatory if all datasets are to be tested");
|
||||
}
|
||||
// Dataset catalogue; used below to validate names (isDataset) and to list them (getNames).
auto datasets = platform::Datasets(false, platform::Paths::datasets());
|
||||
// Case 1: a catalog file was given -- every usable line names one dataset to test.
if (datasets_file != "") {
|
||||
ifstream catalog(datasets_file);
|
||||
if (catalog.is_open()) {
|
||||
std::string line;
|
||||
while (getline(catalog, line)) {
|
||||
// Skip empty lines and lines whose first character is '#'.
if (line.empty() || line[0] == '#') {
|
||||
continue;
|
||||
}
|
||||
// An unknown dataset name terminates the process with exit code 1.
if (!datasets.isDataset(line)) {
|
||||
cerr << "Dataset " << line << " not found" << std::endl;
|
||||
exit(1);
|
||||
}
|
||||
filesToTest.push_back(line);
|
||||
}
|
||||
catalog.close();
|
||||
// Saving is forced on in this case, whatever the "save" option said.
saveResults = true;
|
||||
// Default title built from the dataset count, catalog file, model and fold count.
if (title == "") {
|
||||
title = "Test " + to_string(filesToTest.size()) + " datasets (" + datasets_file + ") "\
|
||||
+ model_name + " " + to_string(n_folds) + " folds";
|
||||
}
|
||||
} else {
|
||||
throw std::invalid_argument("Unable to open catalog file. [" + datasets_file + "]");
|
||||
}
|
||||
} else {
|
||||
// Case 2: an explicit list of dataset names was given.
if (file_names.size() > 0) {
|
||||
for (auto file : file_names) {
|
||||
if (!datasets.isDataset(file)) {
|
||||
cerr << "Dataset " << file << " not found" << std::endl;
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
filesToTest = file_names;
|
||||
saveResults = true;
|
||||
if (title == "") {
|
||||
title = "Test " + to_string(file_names.size()) + " datasets " + model_name + " " + to_string(n_folds) + " folds";
|
||||
}
|
||||
} else {
|
||||
// Case 3: a single dataset name; saveResults keeps the value read from "save".
if (file_name != "all") {
|
||||
if (!datasets.isDataset(file_name)) {
|
||||
cerr << "Dataset " << file_name << " not found" << std::endl;
|
||||
exit(1);
|
||||
}
|
||||
if (title == "") {
|
||||
title = "Test " + file_name + " " + model_name + " " + to_string(n_folds) + " folds";
|
||||
}
|
||||
filesToTest.push_back(file_name);
|
||||
} else {
|
||||
// "all": test every name returned by the catalogue and force saving the results.
filesToTest = datasets.getNames();
|
||||
saveResults = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Hyperparameters come from the file when one is set, otherwise from the parsed JSON.
platform::HyperParameters test_hyperparams;
|
||||
if (hyperparameters_file != "") {
|
||||
test_hyperparams = platform::HyperParameters(datasets.getNames(), hyperparameters_file, hyper_best);
|
||||
} else {
|
||||
test_hyperparams = platform::HyperParameters(datasets.getNames(), hyperparameters_json);
|
||||
}
|
||||
// Copy the parsed options into the grid configuration; quiet is always set to false here.
this->config.model = model_name;
|
||||
this->config.score = score;
|
||||
this->config.discretize = discretize_dataset;
|
||||
this->config.stratified = stratified;
|
||||
this->config.smooth_strategy = smooth_strat;
|
||||
this->config.n_folds = n_folds;
|
||||
this->config.seeds = seeds;
|
||||
this->config.quiet = false;
|
||||
// Configure the experiment member through its chained setters; the platform name is read via DotEnv.
auto env = platform::DotEnv();
|
||||
experiment.setTitle(title).setLanguage("c++").setLanguageVersion("gcc 14.1.1");
|
||||
experiment.setDiscretizationAlgorithm(discretize_algo).setSmoothSrategy(smooth_strat);
|
||||
experiment.setDiscretized(discretize_dataset).setModel(model_name).setPlatform(env.get("platform"));
|
||||
experiment.setStratified(stratified).setNFolds(n_folds).setScoreName(score);
|
||||
experiment.setHyperparameters(test_hyperparams);
|
||||
// Register every requested random seed with the experiment.
for (auto seed : seeds) {
|
||||
experiment.addRandomSeed(seed);
|
||||
}
|
||||
// The experiment is taken from arguments.initializedExperiment(), and the grid configuration is
// filled from that experiment's getters.
// NOTE(review): these lines look like the replacement for all the option handling above -- confirm.
experiment = arguments.initializedExperiment();
|
||||
this->config.model = experiment.getModel();
|
||||
this->config.score = experiment.getScore();
|
||||
this->config.discretize = experiment.isDiscretized();
|
||||
this->config.stratified = experiment.isStratified();
|
||||
this->config.smooth_strategy = experiment.getSmoothStrategy();
|
||||
this->config.n_folds = experiment.getNFolds();
|
||||
this->config.seeds = experiment.getRandomSeeds();
|
||||
this->config.quiet = experiment.isQuiet();
|
||||
}
|
||||
json GridExperiment::getResults()
|
||||
{
|
||||
|
@@ -9,6 +9,7 @@
|
||||
#include "common/DotEnv.h"
|
||||
#include "main/Experiment.h"
|
||||
#include "main/HyperParameters.h"
|
||||
#include "main/ArgumentsExperiment.h"
|
||||
#include "GridData.h"
|
||||
#include "GridBase.h"
|
||||
#include "bayesnet/network/Network.h"
|
||||
@@ -18,14 +19,14 @@ namespace platform {
|
||||
using json = nlohmann::ordered_json;
|
||||
class GridExperiment : public GridBase {
|
||||
public:
|
||||
explicit GridExperiment(argparse::ArgumentParser& program, struct ConfigGrid& config);
|
||||
explicit GridExperiment(ArgumentsExperiment& program, struct ConfigGrid& config);
|
||||
~GridExperiment() = default;
|
||||
json getResults();
|
||||
Experiment& getExperiment() { return experiment; }
|
||||
size_t numFiles() const { return filesToTest.size(); }
|
||||
bool haveToSaveResults() const { return saveResults; }
|
||||
private:
|
||||
argparse::ArgumentParser& arguments;
|
||||
ArgumentsExperiment& arguments;
|
||||
Experiment experiment;
|
||||
json computed_results;
|
||||
bool saveResults;
|
||||
|
Reference in New Issue
Block a user