Add hyperparameters management in experiments

This commit is contained in:
2023-08-20 17:57:38 +02:00
parent 7a6ec73d63
commit 4964aab722
17 changed files with 141 additions and 117 deletions

View File

@@ -29,7 +29,8 @@ namespace platform {
};
class Result {
private:
string dataset, hyperparameters, model_version;
string dataset, model_version;
json hyperparameters;
int samples{ 0 }, features{ 0 }, classes{ 0 };
double score_train{ 0 }, score_test{ 0 }, score_train_std{ 0 }, score_test_std{ 0 }, train_time{ 0 }, train_time_std{ 0 }, test_time{ 0 }, test_time_std{ 0 };
float nodes{ 0 }, leaves{ 0 }, depth{ 0 };
@@ -37,7 +38,7 @@ namespace platform {
public:
Result() = default;
Result& setDataset(const string& dataset) { this->dataset = dataset; return *this; }
Result& setHyperparameters(const string& hyperparameters) { this->hyperparameters = hyperparameters; return *this; }
Result& setHyperparameters(const json& hyperparameters) { this->hyperparameters = hyperparameters; return *this; }
Result& setSamples(int samples) { this->samples = samples; return *this; }
Result& setFeatures(int features) { this->features = features; return *this; }
Result& setClasses(int classes) { this->classes = classes; return *this; }
@@ -59,7 +60,7 @@ namespace platform {
const float get_score_train() const { return score_train; }
float get_score_test() { return score_test; }
const string& getDataset() const { return dataset; }
const string& getHyperparameters() const { return hyperparameters; }
const json& getHyperparameters() const { return hyperparameters; }
const int getSamples() const { return samples; }
const int getFeatures() const { return features; }
const int getClasses() const { return classes; }
@@ -85,11 +86,12 @@ namespace platform {
bool discretized{ false }, stratified{ false };
vector<Result> results;
vector<int> randomSeeds;
json hyperparameters = "{}";
int nfolds{ 0 };
float duration{ 0 };
json build_json();
public:
Experiment() = default;
Experiment();
Experiment& setTitle(const string& title) { this->title = title; return *this; }
Experiment& setModel(const string& model) { this->model = model; return *this; }
Experiment& setPlatform(const string& platform) { this->platform = platform; return *this; }
@@ -103,6 +105,7 @@ namespace platform {
Experiment& addResult(Result result) { results.push_back(result); return *this; }
Experiment& addRandomSeed(int randomSeed) { randomSeeds.push_back(randomSeed); return *this; }
Experiment& setDuration(float duration) { this->duration = duration; return *this; }
Experiment& setHyperparameters(const json& hyperparameters) { this->hyperparameters = hyperparameters; return *this; }
string get_file_name();
void save(const string& path);
void cross_validation(const string& path, const string& fileName);