From e336d39cfb9902053b0ce557943a219fa6a84af2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Tue, 30 Jan 2024 13:11:16 +0100 Subject: [PATCH] Add --no-train-score to b_main --- src/Platform/b_main.cc | 6 ++++-- src/Platform/modules/Experiment.cc | 10 ++++++---- src/Platform/modules/Experiment.h | 4 ++-- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/Platform/b_main.cc b/src/Platform/b_main.cc index 7872239..0f10a65 100644 --- a/src/Platform/b_main.cc +++ b/src/Platform/b_main.cc @@ -31,6 +31,7 @@ void manageArguments(argparse::ArgumentParser& program) ); program.add_argument("--title").default_value("").help("Experiment title"); program.add_argument("--discretize").help("Discretize input dataset").default_value((bool)stoi(env.get("discretize"))).implicit_value(true); + program.add_argument("--no-train-score").help("Don't compute train score").default_value(false).implicit_value(true); program.add_argument("--quiet").help("Don't display detailed progress").default_value(false).implicit_value(true); program.add_argument("--save").help("Save result (always save if no dataset is supplied)").default_value(false).implicit_value(true); program.add_argument("--stratified").help("If Stratified KFold is to be done").default_value((bool)stoi(env.get("stratified"))).implicit_value(true); @@ -58,7 +59,7 @@ int main(int argc, char** argv) manageArguments(program); std::string file_name, model_name, title, hyperparameters_file; json hyperparameters_json; - bool discretize_dataset, stratified, saveResults, quiet; + bool discretize_dataset, stratified, saveResults, quiet, no_train_score; std::vector seeds; std::vector filesToTest; int n_folds; @@ -74,6 +75,7 @@ int main(int argc, char** argv) auto hyperparameters = program.get("hyperparameters"); hyperparameters_json = json::parse(hyperparameters); hyperparameters_file = program.get("hyper-file"); + no_train_score = program.get("no-train-score"); if (hyperparameters_file != "" && hyperparameters != "{}") { throw runtime_error("hyperparameters and hyper_file are mutually exclusive"); } @@ -123,7 +125,7 @@ int main(int argc, char** argv) } platform::Timer timer; timer.start(); - experiment.go(filesToTest, quiet); + experiment.go(filesToTest, quiet, no_train_score); experiment.setDuration(timer.getDuration()); if (saveResults) { experiment.save(platform::Paths::results()); diff --git a/src/Platform/modules/Experiment.cc b/src/Platform/modules/Experiment.cc index a48a290..1b7d891 100644 --- a/src/Platform/modules/Experiment.cc +++ b/src/Platform/modules/Experiment.cc @@ -101,7 +101,7 @@ namespace platform { std::cout << data.dump(4) << std::endl; } - void Experiment::go(std::vector filesToProcess, bool quiet) + void Experiment::go(std::vector filesToProcess, bool quiet, bool no_train_score) { for (auto fileName : filesToProcess) { if (fileName.size() > max_name) @@ -122,7 +122,7 @@ namespace platform { for (auto fileName : filesToProcess) { if (!quiet) std::cout << " " << setw(3) << right << num++ << " " << setw(max_name) << left << fileName << right << flush; - cross_validation(fileName, quiet); + cross_validation(fileName, quiet, no_train_score); if (!quiet) std::cout << std::endl; } @@ -150,7 +150,7 @@ namespace platform { std::cout << prefix << color << fold << Colors::RESET() << "(" << color << phase << Colors::RESET() << ")" << flush; } - void Experiment::cross_validation(const std::string& fileName, bool quiet) + void Experiment::cross_validation(const std::string& fileName, bool quiet, bool no_train_score) { auto datasets = Datasets(discretized, Paths::datasets()); // Get dataset @@ -218,8 +218,10 @@ namespace platform { edges[item] = clf->getNumberOfEdges(); num_states[item] = clf->getNumberOfStates(); train_time[item] = train_timer.getDuration(); + double accuracy_train_value = 0.0; // Score train - auto accuracy_train_value = clf->score(X_train, y_train); + if (!no_train_score) + accuracy_train_value = clf->score(X_train, y_train); // Test model if (!quiet) showProgress(nfold + 1, getColor(clf->getStatus()), "c"); diff --git a/src/Platform/modules/Experiment.h b/src/Platform/modules/Experiment.h index c0a1570..252c7c9 100644 --- a/src/Platform/modules/Experiment.h +++ b/src/Platform/modules/Experiment.h @@ -85,8 +85,8 @@ namespace platform { Experiment& setHyperparameters(const HyperParameters& hyperparameters_) { this->hyperparameters = hyperparameters_; return *this; } std::string get_file_name(); void save(const std::string& path); - void cross_validation(const std::string& fileName, bool quiet); - void go(std::vector filesToProcess, bool quiet); + void cross_validation(const std::string& fileName, bool quiet, bool no_train_score); + void go(std::vector filesToProcess, bool quiet, bool no_train_score); void show(); void report(); private: