diff --git a/src/Platform/BestResults.cc b/src/Platform/BestResults.cc index b5ea179..3220382 100644 --- a/src/Platform/BestResults.cc +++ b/src/Platform/BestResults.cc @@ -9,7 +9,7 @@ namespace platform { string BestResults::build() { - auto files = loadFiles(); + auto files = loadResultFiles(); if (files.size() == 0) { cerr << Colors::MAGENTA() << "No result files were found!" << Colors::RESET() << endl; exit(1); @@ -48,7 +48,7 @@ namespace platform { return "best_results_" + model + "_" + score + ".json"; } - vector BestResults::loadFiles() + vector BestResults::loadResultFiles() { vector files; using std::filesystem::directory_iterator; @@ -56,7 +56,7 @@ namespace platform { auto fileName = file.path().filename().string(); if (fileName.find(".json") != string::npos && fileName.find("results_") == 0 && fileName.find("_" + score + "_") != string::npos - && fileName.find("_" + model + "_") != string::npos) { + && (fileName.find("_" + model + "_") != string::npos || model == "any")) { files.push_back(fileName); } } @@ -71,8 +71,12 @@ namespace platform { } throw invalid_argument("Unable to open result file. [" + fileName + "]"); } + void BestResults::reportAll() + { - void BestResults::report() + } + + void BestResults::reportSingle() { string bestFileName = path + bestResultFile(); if (FILE* fileTest = fopen(bestFileName.c_str(), "r")) { diff --git a/src/Platform/BestResults.h b/src/Platform/BestResults.h index 91bc45a..19e10d7 100644 --- a/src/Platform/BestResults.h +++ b/src/Platform/BestResults.h @@ -9,9 +9,10 @@ namespace platform { public: explicit BestResults(const string& path, const string& score, const string& model) : path(path), score(score), model(model) {} string build(); - void report(); + void reportSingle(); + void reportAll(); private: - vector loadFiles(); + vector loadResultFiles(); string bestResultFile(); json loadFile(const string& fileName); string path; diff --git a/src/Platform/best.cc b/src/Platform/best.cc index e76ccdd..0280a8a 100644 --- a/src/Platform/best.cc +++ b/src/Platform/best.cc @@ -9,7 +9,7 @@ using namespace std; argparse::ArgumentParser manageArguments(int argc, char** argv) { argparse::ArgumentParser program("best"); - program.add_argument("-m", "--model").default_value("").help("Filter results of the selected model)"); + program.add_argument("-m", "--model").default_value("").help("Filter results of the selected model) (any for all models)"); program.add_argument("-s", "--score").default_value("").help("Filter results of the score name supplied"); program.add_argument("--build").help("build best score results file").default_value(false).implicit_value(true); program.add_argument("--report").help("report of best score results file").default_value(false).implicit_value(true); @@ -43,13 +43,22 @@ int main(int argc, char** argv) cerr << program; exit(1); } + if (model == "any" && build) { + cerr << "Can't build best results file for all models. \"any\" is only valid for report" << endl; + cerr << program; + exit(1); + } auto results = platform::BestResults(platform::Paths::results(), model, score); if (build) { string fileName = results.build(); cout << Colors::GREEN() << fileName << " created!" << Colors::RESET() << endl; } if (report) { - results.report(); + if (model == "any") { + results.reportAll(); + } else { + results.reportSingle(); + } } return 0; }