diff --git a/src/best/BestResults.cpp b/src/best/BestResults.cpp index 9f2d779..c0b8622 100644 --- a/src/best/BestResults.cpp +++ b/src/best/BestResults.cpp @@ -6,6 +6,7 @@ #include #include "common/Colors.h" #include "common/CLocale.h" +#include "common/Paths.h" #include "results/Result.h" #include "BestResultsExcel.h" #include "best/Statistics.h" @@ -59,16 +60,12 @@ namespace platform { std::cerr << Colors::MAGENTA() << "No results found for model " << model << " and score " << score << Colors::RESET() << std::endl; exit(1); } - std::string bestFileName = path + bestResultFile(); + std::string bestFileName = path + Paths::bestResultsFile(score, model); std::ofstream file(bestFileName); file << bests; file.close(); return bestFileName; } - std::string BestResults::bestResultFile() - { - return "best_results_" + score + "_" + model + ".json"; - } std::pair getModelScore(std::string name) { // results_accuracy_BoostAODE_MacBookpro16_2023-09-06_12:27:00_1.json @@ -150,7 +147,7 @@ namespace platform { } void BestResults::listFile() { - std::string bestFileName = path + bestResultFile(); + std::string bestFileName = path + Paths::bestResultsFile(score, model); if (FILE* fileTest = fopen(bestFileName.c_str(), "r")) { fclose(fileTest); } else { @@ -196,7 +193,7 @@ namespace platform { auto maxDate = std::filesystem::file_time_type::max(); for (const auto& model : models) { this->model = model; - std::string bestFileName = path + bestResultFile(); + std::string bestFileName = path + Paths::bestResultsFile(score, model); if (FILE* fileTest = fopen(bestFileName.c_str(), "r")) { fclose(fileTest); } else { @@ -306,7 +303,7 @@ namespace platform { json table = buildTableResults(models); std::vector datasets = getDatasets(table.begin().value()); BestResultsExcel excel_report(score, datasets); - excel_report.reportSingle(model, path + bestResultFile()); + excel_report.reportSingle(model, path + Paths::bestResultsFile(score, model)); messageExcelFile(excel_report.getFileName()); 
} } @@ -346,7 +343,7 @@ namespace platform { } } model = models.at(idx); - excel.reportSingle(model, path + bestResultFile()); + excel.reportSingle(model, path + Paths::bestResultsFile(score, model)); } messageExcelFile(excel.getFileName()); } diff --git a/src/best/BestResults.h b/src/best/BestResults.h index 2815e97..2098dea 100644 --- a/src/best/BestResults.h +++ b/src/best/BestResults.h @@ -22,7 +22,6 @@ namespace platform { void messageExcelFile(const std::string& fileName); json buildTableResults(std::vector models); void printTableResults(std::vector models, json table); - std::string bestResultFile(); json loadFile(const std::string& fileName); void listFile(); std::string path; diff --git a/src/commands/b_main.cpp b/src/commands/b_main.cpp index e4b0ea1..ab1de9a 100644 --- a/src/commands/b_main.cpp +++ b/src/commands/b_main.cpp @@ -36,6 +36,7 @@ void manageArguments(argparse::ArgumentParser& program) program.add_argument("--hyperparameters").default_value("{}").help("Hyperparameters passed to the model in Experiment"); program.add_argument("--hyper-file").default_value("").help("Hyperparameters file name." \ "Mutually exclusive with hyperparameters. 
This file should contain hyperparameters for each dataset in json format."); + program.add_argument("--hyper-best").default_value(false).help("Use best results of the model as source of hyperparameters").implicit_value(true); program.add_argument("-m", "--model") .help("Model to use: " + platform::Models::instance()->toString()) .action([](const std::string& value) { @@ -93,7 +94,7 @@ int main(int argc, char** argv) manageArguments(program); std::string file_name, model_name, title, hyperparameters_file, datasets_file, discretize_algo, smooth_strat, score; json hyperparameters_json; - bool discretize_dataset, stratified, saveResults, quiet, no_train_score, generate_fold_files, graph; + bool discretize_dataset, stratified, saveResults, quiet, no_train_score, generate_fold_files, graph, hyper_best; std::vector seeds; std::vector file_names; std::vector filesToTest; @@ -117,9 +118,17 @@ int main(int argc, char** argv) hyperparameters_json = json::parse(hyperparameters); hyperparameters_file = program.get("hyper-file"); no_train_score = program.get("no-train-score"); + hyper_best = program.get("hyper-best"); generate_fold_files = program.get("generate-fold-files"); - if (hyperparameters_file != "" && hyperparameters != "{}") { - throw runtime_error("hyperparameters and hyper_file are mutually exclusive"); + if (hyper_best) { + // Build the best results file_name + hyperparameters_file = platform::Paths::results() + platform::Paths::bestResultsFile(score, model_name); + // ignore this parameter + hyperparameters = "{}"; + } else { + if (hyperparameters_file != "" && hyperparameters != "{}") { + throw runtime_error("hyperparameters and hyper_file are mutually exclusive"); + } } title = program.get("title"); if (title == "" && file_name == "all") { @@ -188,7 +197,7 @@ int main(int argc, char** argv) platform::HyperParameters test_hyperparams; if (hyperparameters_file != "") { - test_hyperparams = platform::HyperParameters(datasets.getNames(), hyperparameters_file); + 
test_hyperparams = platform::HyperParameters(datasets.getNames(), hyperparameters_file, hyper_best); } else { test_hyperparams = platform::HyperParameters(datasets.getNames(), hyperparameters_json); } diff --git a/src/common/Paths.h b/src/common/Paths.h index 6c6e471..f8b8f8e 100644 --- a/src/common/Paths.h +++ b/src/common/Paths.h @@ -32,6 +32,10 @@ namespace platform { throw std::runtime_error("Could not create directory " + path); } } + static std::string bestResultsFile(const std::string& score, const std::string& model) + { + return "best_results_" + score + "_" + model + ".json"; + } static std::string excelResults() { return "some_results.xlsx"; } static std::string grid_input(const std::string& model) { diff --git a/src/main/HyperParameters.cpp b/src/main/HyperParameters.cpp index f10113e..719a921 100644 --- a/src/main/HyperParameters.cpp +++ b/src/main/HyperParameters.cpp @@ -19,7 +19,7 @@ namespace platform { std::ostream_iterator(ss, delim.c_str())); return ss.str(); } - HyperParameters::HyperParameters(const std::vector& datasets, const std::string& hyperparameters_file) + HyperParameters::HyperParameters(const std::vector& datasets, const std::string& hyperparameters_file, bool best) { // Check if file exists std::ifstream file(hyperparameters_file); @@ -28,7 +28,14 @@ namespace platform { } // Check if file is a json json file_hyperparameters = json::parse(file); - auto input_hyperparameters = file_hyperparameters["results"]; + json input_hyperparameters; + if (best) { + for (const auto& [key, value] : file_hyperparameters.items()) { + input_hyperparameters[key] = value[1]; + } + } else { + input_hyperparameters = file_hyperparameters["results"]; + } // Check if hyperparameters are valid for (const auto& dataset : datasets) { if (!input_hyperparameters.contains(dataset)) { diff --git a/src/main/HyperParameters.h b/src/main/HyperParameters.h index 2a1aec0..ed1b948 100644 --- a/src/main/HyperParameters.h +++ b/src/main/HyperParameters.h @@ -10,14 +10,17 
@@ namespace platform { class HyperParameters { public: HyperParameters() = default; + // Constructor to use hyperparameters supplied as JSON on the command line explicit HyperParameters(const std::vector& datasets, const json& hyperparameters_); - explicit HyperParameters(const std::vector& datasets, const std::string& hyperparameters_file); + // Constructor to use a hyperparameters file generated by grid or by best results; pass best = true for the best-results file layout + explicit HyperParameters(const std::vector& datasets, const std::string& hyperparameters_file, bool best = false); ~HyperParameters() = default; bool notEmpty(const std::string& key) const { return !hyperparameters.at(key).empty(); } void check(const std::vector& valid, const std::string& fileName); json get(const std::string& fileName); private: std::map hyperparameters; + bool best = false; // Distinguishes grid from best-results hyperparameters files, whose formats differ. NOTE(review): confirm the constructor stores its `best` argument here }; } /* namespace platform */ #endif \ No newline at end of file