From 2729b92f065b30f2da3d57049473bf2515a983e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Sun, 13 Aug 2023 16:19:17 +0200 Subject: [PATCH 1/5] Summary list --- .vscode/launch.json | 11 +++++++ src/Platform/CMakeLists.txt | 4 ++- src/Platform/Results.cc | 60 +++++++++++++++++++++++++++++++++++++ src/Platform/Results.h | 38 +++++++++++++++++++++++ src/Platform/main.cc | 2 +- src/Platform/manage.cc | 32 ++++++++++++++++++++ 6 files changed, 145 insertions(+), 2 deletions(-) create mode 100644 src/Platform/Results.cc create mode 100644 src/Platform/Results.h create mode 100644 src/Platform/manage.cc diff --git a/.vscode/launch.json b/.vscode/launch.json index ba01ca6..0a7a483 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -34,6 +34,17 @@ ], "cwd": "/Users/rmontanana/Code/discretizbench", }, + { + "type": "lldb", + "request": "launch", + "name": "manage", + "program": "${workspaceFolder}/build/src/Platform/manage", + "args": [ + "-n", + "20" + ], + "cwd": "/Users/rmontanana/Code/discretizbench", + }, { "name": "Build & debug active file", "type": "cppdbg", diff --git a/src/Platform/CMakeLists.txt b/src/Platform/CMakeLists.txt index 3b13abc..0eb26ce 100644 --- a/src/Platform/CMakeLists.txt +++ b/src/Platform/CMakeLists.txt @@ -5,4 +5,6 @@ include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp) include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include) include_directories(${BayesNet_SOURCE_DIR}/lib/json/include) add_executable(main main.cc Folding.cc platformUtils.cc Experiment.cc Datasets.cc Models.cc Report.cc) -target_link_libraries(main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}") \ No newline at end of file +add_executable(manage manage.cc Results.cc Report.cc) +target_link_libraries(main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}") +target_link_libraries(manage "${TORCH_LIBRARIES}") \ No newline at end of file diff --git a/src/Platform/Results.cc b/src/Platform/Results.cc new file mode 100644 index 0000000..ee5315b --- /dev/null +++ b/src/Platform/Results.cc @@ -0,0 +1,60 @@ +#include +#include "platformUtils.h" +#include "Results.h" +namespace platform { + const double REFERENCE_SCORE = 22.109799; + Result::Result(const string& path, const string& filename) + : path(path) + , filename(filename) + { + auto data = load(); + date = data["date"]; + score = 0; + for (const auto& result : data["results"]) { + score += result["score"].get(); + } + score /= REFERENCE_SCORE; + title = data["title"]; + duration = data["duration"]; + model = data["model"]; + } + json Result::load() + { + ifstream resultData(path + "/" + filename); + if (resultData.is_open()) { + json data = json::parse(resultData); + return data; + } + throw invalid_argument("Unable to open result file. [" + path + "/" + filename + "]"); + } + void Results::load() + { + using std::filesystem::directory_iterator; + for (const auto& file : directory_iterator(path)) { + auto filename = file.path().filename().string(); + if (filename.find(".json") != string::npos && filename.find("results_") == 0) { + auto result = Result(path, filename); + files.push_back(result); + } + } + } + string Result::to_string() const + { + stringstream oss; + oss << date << " "; + oss << setw(12) << left << model << " "; + oss << right << setw(9) << setprecision(7) << fixed << score << " "; + oss << setw(9) << setprecision(3) << fixed << duration << " "; + oss << setw(50) << left << title << " "; + return oss.str(); + } + void Results::manage() + { + cout << "Results found: " << files.size() << endl; + cout << "========================" << endl; + for (const auto& result : files) { + cout << result.to_string() << endl; + } + } + +} \ No newline at end of file diff --git a/src/Platform/Results.h b/src/Platform/Results.h new file mode 100644 index 0000000..5d36f32 --- /dev/null +++ b/src/Platform/Results.h @@ -0,0 +1,38 @@ +#ifndef RESULTS_H +#define RESULTS_H +#include +#include +#include +#include +namespace platform { + using namespace std; + using json = nlohmann::json; + + class Result { + public: + Result(const string& path, const string& filename); + json load(); + string to_string() const; + private: + string path; + string filename; + string date; + double score; + string title; + double duration; + string model; + }; + class Results { + public: + explicit Results(const string& path) : path(path) { load(); }; + void manage(); + private: + string path; + vector files; + void load(); // Loads the list of results + void show(); + int menu(); + }; +}; + +#endif \ No newline at end of file diff --git a/src/Platform/main.cc b/src/Platform/main.cc index 24d0a33..7692629 100644 --- a/src/Platform/main.cc +++ b/src/Platform/main.cc @@ -14,7 +14,7 @@ const string PATH_DATASETS = "datasets"; argparse::ArgumentParser manageArguments(int argc, char** argv) { auto env = platform::DotEnv(); - argparse::ArgumentParser program("BayesNetSample"); + argparse::ArgumentParser program("main"); program.add_argument("-d", "--dataset").default_value("").help("Dataset file name"); program.add_argument("-p", "--path") .help("folder where the data files are located, default") diff --git a/src/Platform/manage.cc b/src/Platform/manage.cc new file mode 100644 index 0000000..b901601 --- /dev/null +++ b/src/Platform/manage.cc @@ -0,0 +1,32 @@ +#include +#include +#include "platformUtils.h" +#include "Results.h" + +using namespace std; +const string PATH_RESULTS = "results"; + +argparse::ArgumentParser manageArguments(int argc, char** argv) +{ + argparse::ArgumentParser program("manage"); + program.add_argument("-n", "--number").default_value(0).help("Number of results to show (0 = all)").scan<'i', int>(); + try { + program.parse_args(argc, argv); + auto number = program.get("number"); + } + catch (const exception& err) { + cerr << err.what() << endl; + cerr << program; + exit(1); + } + return program; +} + +int main(int argc, char** argv) +{ + auto program = manageArguments(argc, argv); + auto number = program.get("number"); + auto results = platform::Results(PATH_RESULTS); + results.manage(); + return 0; +} -- 2.45.2 From 054567c65a252dafb6090a85918bc3c22d4038bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Sun, 13 Aug 2023 17:10:18 +0200 Subject: [PATCH 2/5] Add sorting capacity --- src/Platform/Results.cc | 132 +++++++++++++++++++++++++++++++++++++++- src/Platform/Results.h | 19 +++++- src/Platform/manage.cc | 5 +- 3 files changed, 150 insertions(+), 6 deletions(-) diff --git a/src/Platform/Results.cc b/src/Platform/Results.cc index ee5315b..056f0b1 100644 --- a/src/Platform/Results.cc +++ b/src/Platform/Results.cc @@ -1,6 +1,7 @@ #include #include "platformUtils.h" #include "Results.h" +#include "Report.h" namespace platform { const double REFERENCE_SCORE = 22.109799; Result::Result(const string& path, const string& filename) @@ -48,13 +49,140 @@ namespace platform { oss << setw(50) << left << title << " "; return oss.str(); } - void Results::manage() + void Results::show() const { cout << "Results found: " << files.size() << endl; - cout << "========================" << endl; + cout << "-------------------" << endl; + auto i = 0; + cout << " # Date Model Score Duration Title" << endl; + cout << "=== ========== ============ ========= ========= =============================================================" << endl; for (const auto& result : files) { + cout << setw(3) << fixed << right << i++ << " "; cout << result.to_string() << endl; + if (i == max && max != 0) { + break; + } + } } + int Results::getIndex(const string& intent) const + { + cout << "Choose result to " << intent << ": "; + int index; + cin >> index; + if (index >= 0 && index < files.size()) { + return index; + } + + cout << "Invalid index" << endl; + return -1; + } + void Results::menu() + { + cout << "Choose option (quit='q', list='l', delete='d', hide='h', sort='s', report='r'): "; + char option; + int index; + string filename; + cin >> option; + switch (option) { + case 'q': + exit(0); + case 'l': + show(); + menu(); + break; + case 'd': + index = getIndex("delete"); + if (index == -1) + break; + filename = files[index].getFilename(); + cout << "Deleting " << filename << endl; + remove((path + "/" + filename).c_str()); + files.erase(files.begin() + index); + show(); + menu(); + break; + case 'h': + index = getIndex("hide"); + if (index == -1) + break; + filename = files[index].getFilename(); + cout << "Hiding " << filename << endl; + rename((path + "/" + filename).c_str(), (path + "/." + filename).c_str()); + files.erase(files.begin() + index); + show(); + menu(); + break; + case 's': + sortList(); + show(); + menu(); + break; + case 'r': + index = getIndex("report"); + if (index == -1) + break; + filename = files[index].getFilename(); + cout << "Reporting " << filename << endl; + auto data = files[index].load(); + Report report(data); + report.show(); + menu(); + break; + + } + } + void Results::sortList() + { + cout << "Choose sorting field (date='d', score='s', duration='u', model='m'): "; + char option; + cin >> option; + switch (option) { + case 'd': + sortDate(); + break; + case 's': + sortScore(); + break; + case 'u': + sortDuration(); + break; + case 'm': + sortModel(); + break; + default: + cout << "Invalid option" << endl; + } + + } + void Results::sortDate() + { + sort(files.begin(), files.end(), [](const Result& a, const Result& b) { + return a.getDate() > b.getDate(); + }); + } + void Results::sortModel() + { + sort(files.begin(), files.end(), [](const Result& a, const Result& b) { + return a.getModel() > b.getModel(); + }); + } + void Results::sortDuration() + { + sort(files.begin(), files.end(), [](const Result& a, const Result& b) { + return a.getDuration() > b.getDuration(); + }); + } + void Results::sortScore() + { + sort(files.begin(), files.end(), [](const Result& a, const Result& b) { + return a.getScore() > b.getScore(); + }); + } + void Results::manage() + { + show(); + menu(); + } } \ No newline at end of file diff --git a/src/Platform/Results.h b/src/Platform/Results.h index 5d36f32..945901f 100644 --- a/src/Platform/Results.h +++ b/src/Platform/Results.h @@ -13,6 +13,12 @@ namespace platform { Result(const string& path, const string& filename); json load(); string to_string() const; + string getFilename() const { return filename; }; + string getDate() const { return date; }; + double getScore() const { return score; }; + string getTitle() const { return title; }; + double getDuration() const { return duration; }; + string getModel() const { return model; }; private: string path; string filename; @@ -24,14 +30,21 @@ namespace platform { }; class Results { public: - explicit Results(const string& path) : path(path) { load(); }; + explicit Results(const string& path, const int max) : path(path), max(max) { load(); }; void manage(); private: string path; + int max; vector files; void load(); // Loads the list of results - void show(); - int menu(); + void show() const; + int getIndex(const string& intent) const; + void menu(); + void sortList(); + void sortDate(); + void sortScore(); + void sortModel(); + void sortDuration(); }; }; diff --git a/src/Platform/manage.cc b/src/Platform/manage.cc index b901601..f97dae3 100644 --- a/src/Platform/manage.cc +++ b/src/Platform/manage.cc @@ -13,6 +13,9 @@ argparse::ArgumentParser manageArguments(int argc, char** argv) try { program.parse_args(argc, argv); auto number = program.get("number"); + if (number < 0) { + throw runtime_error("Number of results must be greater than or equal to 0"); + } } catch (const exception& err) { cerr << err.what() << endl; @@ -26,7 +29,7 @@ int main(int argc, char** argv) { auto program = manageArguments(argc, argv); auto number = program.get("number"); - auto results = platform::Results(PATH_RESULTS); + auto results = platform::Results(PATH_RESULTS, number); results.manage(); return 0; } -- 2.45.2 From 3691cb4a614b96d0e0145fadb8d365c25f2a23fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Sun, 13 Aug 2023 18:13:00 +0200 Subject: [PATCH 3/5] Add totals and filter by scoreName and model --- src/Platform/BestResult.h | 10 ++++++++++ src/Platform/Report.cc | 26 ++++++++++++++++++++++---- src/Platform/Report.h | 2 ++ src/Platform/Results.cc | 24 ++++++++++++++++++------ src/Platform/Results.h | 6 +++++- src/Platform/manage.cc | 8 +++++++- 6 files changed, 64 insertions(+), 12 deletions(-) create mode 100644 src/Platform/BestResult.h diff --git a/src/Platform/BestResult.h b/src/Platform/BestResult.h new file mode 100644 index 0000000..8b3f1cb --- /dev/null +++ b/src/Platform/BestResult.h @@ -0,0 +1,10 @@ +#ifndef BESTRESULT_H +#define BESTRESULT_H +#include +class BestResult { +public: + static std::string title() { return "STree_default (linear-ovo)"; } + static double score() { return 22.109799; } + static std::string scoreName() { return "accuracy"; } +}; +#endif \ No newline at end of file diff --git a/src/Platform/Report.cc b/src/Platform/Report.cc index 3693248..7bd7d69 100644 --- a/src/Platform/Report.cc +++ b/src/Platform/Report.cc @@ -1,4 +1,5 @@ #include "Report.h" +#include "BestResult.h" namespace platform { string headerLine(const string& text) @@ -28,6 +29,7 @@ namespace platform { { header(); body(); + footer(); } void Report::header() { @@ -44,6 +46,8 @@ namespace platform { { cout << "Dataset Sampl. Feat. Cls Nodes Edges States Score Time Hyperparameters" << endl; cout << "============================== ====== ===== === ======= ======= ======= =============== ================= ===============" << endl; + json lastResult; + totalScore = 0; for (const auto& r : data["results"]) { cout << setw(30) << left << r["dataset"].get() << " "; cout << setw(6) << right << r["samples"].get() << " "; @@ -56,12 +60,26 @@ namespace platform { cout << setw(10) << right << setprecision(6) << fixed << r["test_time"].get() << "±" << setw(6) << setprecision(4) << fixed << r["test_time_std"].get() << " "; cout << " " << r["hyperparameters"].get(); cout << endl; + lastResult = r; + totalScore += r["score_test"].get(); + } + if (data["results"].size() == 1) { cout << string(MAXL, '*') << endl; - cout << headerLine("Train scores: " + fVector(r["scores_train"])); - cout << headerLine("Test scores: " + fVector(r["scores_test"])); - cout << headerLine("Train times: " + fVector(r["times_train"])); - cout << headerLine("Test times: " + fVector(r["times_test"])); + cout << headerLine("Train scores: " + fVector(lastResult["scores_train"])); + cout << headerLine("Test scores: " + fVector(lastResult["scores_test"])); + cout << headerLine("Train times: " + fVector(lastResult["times_train"])); + cout << headerLine("Test times: " + fVector(lastResult["times_test"])); cout << string(MAXL, '*') << endl; } } + void Report::footer() + { + cout << string(MAXL, '*') << endl; + auto score = data["score_name"].get(); + if (score == BestResult::scoreName()) { + cout << headerLine(score + " compared to " + BestResult::title() + " .: " + to_string(totalScore / BestResult::score())); + } + cout << string(MAXL, '*') << endl; + + } } \ No newline at end of file diff --git a/src/Platform/Report.h b/src/Platform/Report.h index c6ea8a1..302ac60 100644 --- a/src/Platform/Report.h +++ b/src/Platform/Report.h @@ -16,8 +16,10 @@ namespace platform { private: void header(); void body(); + void footer(); string fromVector(const string& key); json data; + double totalScore; // Total score of all results in a report }; }; #endif \ No newline at end of file diff --git a/src/Platform/Results.cc b/src/Platform/Results.cc index 056f0b1..c33cf37 100644 --- a/src/Platform/Results.cc +++ b/src/Platform/Results.cc @@ -2,8 +2,8 @@ #include "platformUtils.h" #include "Results.h" #include "Report.h" +#include "BestResult.h" namespace platform { - const double REFERENCE_SCORE = 22.109799; Result::Result(const string& path, const string& filename) : path(path) , filename(filename) @@ -14,7 +14,10 @@ namespace platform { for (const auto& result : data["results"]) { score += result["score"].get(); } - score /= REFERENCE_SCORE; + scoreName = data["score_name"]; + if (scoreName == BestResult::scoreName()) { + score /= BestResult::score(); + } title = data["title"]; duration = data["duration"]; model = data["model"]; @@ -35,7 +38,11 @@ namespace platform { auto filename = file.path().filename().string(); if (filename.find(".json") != string::npos && filename.find("results_") == 0) { auto result = Result(path, filename); - files.push_back(result); + bool addResult = true; + if (model != "any" && result.getModel() != model || scoreName != "any" && scoreName != result.getScoreName()) + addResult = false; + if (addResult) + files.push_back(result); } } } @@ -44,7 +51,8 @@ namespace platform { stringstream oss; oss << date << " "; oss << setw(12) << left << model << " "; - oss << right << setw(9) << setprecision(7) << fixed << score << " "; + oss << setw(11) << left << scoreName << " "; + oss << right << setw(11) << setprecision(7) << fixed << score << " "; oss << setw(9) << setprecision(3) << fixed << duration << " "; oss << setw(50) << left << title << " "; return oss.str(); @@ -54,8 +62,8 @@ namespace platform { cout << "Results found: " << files.size() << endl; cout << "-------------------" << endl; auto i = 0; - cout << " # Date Model Score Duration Title" << endl; - cout << "=== ========== ============ ========= ========= =============================================================" << endl; + cout << " # Date Model Score Name Score Duration Title" << endl; + cout << "=== ========== ============ =========== =========== ========= =============================================================" << endl; for (const auto& result : files) { cout << setw(3) << fixed << right << i++ << " "; cout << result.to_string() << endl; @@ -181,6 +189,10 @@ namespace platform { } void Results::manage() { + if (files.size() == 0) { + cout << "No results found!" << endl; + exit(0); + } show(); menu(); } diff --git a/src/Platform/Results.h b/src/Platform/Results.h index 945901f..bd4768b 100644 --- a/src/Platform/Results.h +++ b/src/Platform/Results.h @@ -19,6 +19,7 @@ namespace platform { string getTitle() const { return title; }; double getDuration() const { return duration; }; string getModel() const { return model; }; + string getScoreName() const { return scoreName; }; private: string path; string filename; @@ -27,14 +28,17 @@ namespace platform { string title; double duration; string model; + string scoreName; }; class Results { public: - explicit Results(const string& path, const int max) : path(path), max(max) { load(); }; + Results(const string& path, const int max, const string& model, const string& score) : path(path), max(max), model(model), scoreName(score) { load(); }; void manage(); private: string path; int max; + string model; + string scoreName; vector files; void load(); // Loads the list of results void show() const; diff --git a/src/Platform/manage.cc b/src/Platform/manage.cc index f97dae3..74e4a2c 100644 --- a/src/Platform/manage.cc +++ b/src/Platform/manage.cc @@ -10,12 +10,16 @@ argparse::ArgumentParser manageArguments(int argc, char** argv) { argparse::ArgumentParser program("manage"); program.add_argument("-n", "--number").default_value(0).help("Number of results to show (0 = all)").scan<'i', int>(); + program.add_argument("-m", "--model").default_value("any").help("Filter results of the selected model)"); + program.add_argument("-s", "--score").default_value("any").help("Filter results of the score name supplied"); try { program.parse_args(argc, argv); auto number = program.get("number"); if (number < 0) { throw runtime_error("Number of results must be greater than or equal to 0"); } + auto model = program.get("model"); + auto score = program.get("score"); } catch (const exception& err) { cerr << err.what() << endl; @@ -29,7 +33,9 @@ int main(int argc, char** argv) { auto program = manageArguments(argc, argv); auto number = program.get("number"); - auto results = platform::Results(PATH_RESULTS, number); + auto model = program.get("model"); + auto score = program.get("score"); + auto results = platform::Results(PATH_RESULTS, number, model, score); results.manage(); return 0; } -- 2.45.2 From 55d21294d5460addd17e339729386ae804896dc8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Mon, 14 Aug 2023 00:40:31 +0200 Subject: [PATCH 4/5] Add class Paths and enhance input --- src/Platform/Paths.h | 10 +++ src/Platform/Results.cc | 129 ++++++++++++++++++++-------------- src/Platform/Results.h | 3 +- src/Platform/main.cc | 9 ++- src/Platform/manage.cc | 4 +- src/Platform/platformUtils.cc | 3 +- 6 files changed, 98 insertions(+), 60 deletions(-) create mode 100644 src/Platform/Paths.h diff --git a/src/Platform/Paths.h b/src/Platform/Paths.h new file mode 100644 index 0000000..756e61a --- /dev/null +++ b/src/Platform/Paths.h @@ -0,0 +1,10 @@ +#ifndef PATHS_H +#define PATHS_H +namespace platform { + class Paths { + public: + static std::string datasets() { return "datasets/"; } + static std::string results() { return "results/"; } + }; +} +#endif \ No newline at end of file diff --git a/src/Platform/Results.cc b/src/Platform/Results.cc index c33cf37..48c7e9a 100644 --- a/src/Platform/Results.cc +++ b/src/Platform/Results.cc @@ -22,7 +22,7 @@ namespace platform { duration = data["duration"]; model = data["model"]; } - json Result::load() + json Result::load() const { ifstream resultData(path + "/" + filename); if (resultData.is_open()) { @@ -70,7 +70,6 @@ namespace platform { if (i == max && max != 0) { break; } - } } int Results::getIndex(const string& intent) const @@ -81,70 +80,98 @@ namespace platform { if (index >= 0 && index < files.size()) { return index; } - cout << "Invalid index" << endl; return -1; } + void Results::report(const int index) const + { + cout << "Reporting " << files.at(index).getFilename() << endl; + auto data = files.at(index).load(); + Report report(data); + report.show(); + } void Results::menu() { - cout << "Choose option (quit='q', list='l', delete='d', hide='h', sort='s', report='r'): "; char option; int index; - string filename; - cin >> option; - switch (option) { - case 'q': - exit(0); - case 'l': - show(); - menu(); - break; - case 'd': - index = getIndex("delete"); - if (index == -1) + bool finished = false; + string filename, line, options = "qldhsr"; + while (!finished) { + cout << "Choose option (quit='q', list='l', delete='d', hide='h', sort='s', report='r'): "; + getline(cin, line); + if (line.size() == 0) + continue; + if (options.find(line[0]) != string::npos) { + if (line.size() > 1) { + cout << "Invalid option" << endl; + continue; + } + option = line[0]; + } else { + index = stoi(line); + if (index >= 0 && index < files.size()) { + report(index); + } else { + cout << "Invalid option" << endl; + } + continue; + } + switch (option) { + case 'q': + finished = true; break; - filename = files[index].getFilename(); - cout << "Deleting " << filename << endl; - remove((path + "/" + filename).c_str()); - files.erase(files.begin() + index); - show(); - menu(); - break; - case 'h': - index = getIndex("hide"); - if (index == -1) + case 'l': + show(); break; - filename = files[index].getFilename(); - cout << "Hiding " << filename << endl; - rename((path + "/" + filename).c_str(), (path + "/." + filename).c_str()); - files.erase(files.begin() + index); - show(); - menu(); - break; - case 's': - sortList(); - show(); - menu(); - break; - case 'r': - index = getIndex("report"); - if (index == -1) + case 'd': + index = getIndex("delete"); + if (index == -1) + break; + filename = files[index].getFilename(); + cout << "Deleting " << filename << endl; + remove((path + "/" + filename).c_str()); + files.erase(files.begin() + index); + show(); break; - filename = files[index].getFilename(); - cout << "Reporting " << filename << endl; - auto data = files[index].load(); - Report report(data); - report.show(); - menu(); - break; - + case 'h': + index = getIndex("hide"); + if (index == -1) + break; + filename = files[index].getFilename(); + cout << "Hiding " << filename << endl; + rename((path + "/" + filename).c_str(), (path + "/." + filename).c_str()); + files.erase(files.begin() + index); + show(); + menu(); + break; + case 's': + sortList(); + show(); + break; + case 'r': + index = getIndex("report"); + if (index == -1) + break; + report(index); + break; + default: + cout << "Invalid option" << endl; + } } } void Results::sortList() { cout << "Choose sorting field (date='d', score='s', duration='u', model='m'): "; + string line; char option; - cin >> option; + getline(cin, line); + if (line.size() == 0) + return; + if (line.size() > 1) { + cout << "Invalid option" << endl; + return; + } + option = line[0]; switch (option) { case 'd': sortDate(); @@ -161,7 +188,6 @@ namespace platform { default: cout << "Invalid option" << endl; } - } void Results::sortDate() { @@ -195,6 +221,7 @@ namespace platform { } show(); menu(); + cout << "Done!" << endl; } } \ No newline at end of file diff --git a/src/Platform/Results.h b/src/Platform/Results.h index bd4768b..e6b1552 100644 --- a/src/Platform/Results.h +++ b/src/Platform/Results.h @@ -11,7 +11,7 @@ namespace platform { class Result { public: Result(const string& path, const string& filename); - json load(); + json load() const; string to_string() const; string getFilename() const { return filename; }; string getDate() const { return date; }; @@ -42,6 +42,7 @@ namespace platform { vector files; void load(); // Loads the list of results void show() const; + void report(const int index) const; int getIndex(const string& intent) const; void menu(); void sortList(); diff --git a/src/Platform/main.cc b/src/Platform/main.cc index 7692629..0618c89 100644 --- a/src/Platform/main.cc +++ b/src/Platform/main.cc @@ -6,10 +6,10 @@ #include "DotEnv.h" #include "Models.h" #include "modelRegister.h" +#include "Paths.h" + using namespace std; -const string PATH_RESULTS = "results"; -const string PATH_DATASETS = "datasets"; argparse::ArgumentParser manageArguments(int argc, char** argv) { @@ -18,8 +18,7 @@ argparse::ArgumentParser manageArguments(int argc, char** argv) program.add_argument("-d", "--dataset").default_value("").help("Dataset file name"); program.add_argument("-p", "--path") .help("folder where the data files are located, default") - .default_value(string{ PATH_DATASETS } - ); + .default_value(string{ platform::Paths::datasets() }); program.add_argument("-m", "--model") .help("Model to use " + platform::Models::instance()->toString()) .action([](const std::string& value) { @@ -115,7 +114,7 @@ int main(int argc, char** argv) experiment.go(filesToTest, path); experiment.setDuration(timer.getDuration()); if (saveResults) - experiment.save(PATH_RESULTS); + experiment.save(platform::Paths::results()); else experiment.report(); cout << "Done!" << endl; diff --git a/src/Platform/manage.cc b/src/Platform/manage.cc index 74e4a2c..34e66cd 100644 --- a/src/Platform/manage.cc +++ b/src/Platform/manage.cc @@ -1,10 +1,10 @@ #include #include #include "platformUtils.h" +#include "Paths.h" #include "Results.h" using namespace std; -const string PATH_RESULTS = "results"; argparse::ArgumentParser manageArguments(int argc, char** argv) { @@ -35,7 +35,7 @@ int main(int argc, char** argv) auto number = program.get("number"); auto model = program.get("model"); auto score = program.get("score"); - auto results = platform::Results(PATH_RESULTS, number, model, score); + auto results = platform::Results(platform::Paths::results(), number, model, score); results.manage(); return 0; } diff --git a/src/Platform/platformUtils.cc b/src/Platform/platformUtils.cc index 6fca9d9..74e97fd 100644 --- a/src/Platform/platformUtils.cc +++ b/src/Platform/platformUtils.cc @@ -1,4 +1,5 @@ #include "platformUtils.h" +#include "Paths.h" using namespace torch; @@ -85,7 +86,7 @@ tuple, string, map>> loadData tuple>, vector, vector, string, map>> loadFile(const string& name) { auto handler = ArffFiles(); - handler.load(PATH + static_cast(name) + ".arff"); + handler.load(platform::Paths::datasets() + static_cast(name) + ".arff"); // Get Dataset X, y vector& X = handler.getX(); mdlp::labels_t& y = handler.getY(); -- 2.45.2 From 2a3fc9aa4569d39b9d296868400e038acc547ebe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Mon, 14 Aug 2023 17:03:06 +0200 Subject: [PATCH 5/5] Add colors and enhace input control --- src/Platform/Colors.h | 14 ++++++++++++++ src/Platform/Report.cc | 28 ++++++++++++++++++---------- src/Platform/Report.h | 1 + src/Platform/Results.cc | 30 +++++++++++++++++++++--------- 4 files changed, 54 insertions(+), 19 deletions(-) create mode 100644 src/Platform/Colors.h diff --git a/src/Platform/Colors.h b/src/Platform/Colors.h new file mode 100644 index 0000000..7ab2e08 --- /dev/null +++ b/src/Platform/Colors.h @@ -0,0 +1,14 @@ +#ifndef COLORS_H +#define COLORS_H +class Colors { +public: + static std::string MAGENTA() { return "\033[1;35m"; } + static std::string BLUE() { return "\033[1;34m"; } + static std::string CYAN() { return "\033[1;36m"; } + static std::string GREEN() { return "\033[1;32m"; } + static std::string YELLOW() { return "\033[1;33m"; } + static std::string RED() { return "\033[1;31m"; } + static std::string WHITE() { return "\033[1;37m"; } + static std::string RESET() { return "\033[0m"; } +}; +#endif // COLORS_H \ No newline at end of file diff --git a/src/Platform/Report.cc b/src/Platform/Report.cc index 7bd7d69..a40a482 100644 --- a/src/Platform/Report.cc +++ b/src/Platform/Report.cc @@ -33,7 +33,7 @@ namespace platform { } void Report::header() { - cout << string(MAXL, '*') << endl; + cout << Colors::MAGENTA() << string(MAXL, '*') << endl; cout << headerLine("Report " + data["model"].get() + " ver. " + data["version"].get() + " with " + to_string(data["folds"].get()) + " Folds cross validation and " + to_string(data["seeds"].size()) + " random seeds. " + data["date"].get() + " " + data["time"].get()); cout << headerLine(data["title"].get()); cout << headerLine("Random seeds: " + fromVector("seeds") + " Stratified: " + (data["stratified"].get() ? "True" : "False")); @@ -44,24 +44,32 @@ namespace platform { } void Report::body() { - cout << "Dataset Sampl. Feat. Cls Nodes Edges States Score Time Hyperparameters" << endl; - cout << "============================== ====== ===== === ======= ======= ======= =============== ================= ===============" << endl; + cout << Colors::GREEN() << "Dataset Sampl. Feat. Cls Nodes Edges States Score Time Hyperparameters" << endl; + cout << "============================== ====== ===== === ======= ======= ======= =============== ================== ===============" << endl; json lastResult; totalScore = 0; + bool odd = true; for (const auto& r : data["results"]) { - cout << setw(30) << left << r["dataset"].get() << " "; + auto color = odd ? Colors::CYAN() : Colors::BLUE(); + cout << color << setw(30) << left << r["dataset"].get() << " "; cout << setw(6) << right << r["samples"].get() << " "; cout << setw(5) << right << r["features"].get() << " "; cout << setw(3) << right << r["classes"].get() << " "; cout << setw(7) << setprecision(2) << fixed << r["nodes"].get() << " "; cout << setw(7) << setprecision(2) << fixed << r["leaves"].get() << " "; cout << setw(7) << setprecision(2) << fixed << r["depth"].get() << " "; - cout << setw(8) << right << setprecision(6) << fixed << r["score_test"].get() << "±" << setw(6) << setprecision(4) << fixed << r["score_test_std"].get() << " "; - cout << setw(10) << right << setprecision(6) << fixed << r["test_time"].get() << "±" << setw(6) << setprecision(4) << fixed << r["test_time_std"].get() << " "; - cout << " " << r["hyperparameters"].get(); + cout << setw(8) << right << setprecision(6) << fixed << r["score"].get() << "±" << setw(6) << setprecision(4) << fixed << r["score_std"].get() << " "; + cout << setw(11) << right << setprecision(6) << fixed << r["time"].get() << "±" << setw(6) << setprecision(4) << fixed << r["time_std"].get() << " "; + try { + cout << r["hyperparameters"].get(); + } + catch (const exception& err) { + cout << r["hyperparameters"]; + } cout << endl; lastResult = r; - totalScore += r["score_test"].get(); + totalScore += r["score"].get(); + odd = !odd; } if (data["results"].size() == 1) { cout << string(MAXL, '*') << endl; @@ -74,12 +82,12 @@ namespace platform { } void Report::footer() { - cout << string(MAXL, '*') << endl; + cout << Colors::MAGENTA() << string(MAXL, '*') << endl; auto score = data["score_name"].get(); if (score == BestResult::scoreName()) { cout << headerLine(score + " compared to " + BestResult::title() + " .: " + to_string(totalScore / BestResult::score())); } - cout << string(MAXL, '*') << endl; + cout << string(MAXL, '*') << endl << Colors::RESET(); } } \ No newline at end of file diff --git a/src/Platform/Report.h b/src/Platform/Report.h index 302ac60..5934b2f 100644 --- a/src/Platform/Report.h +++ b/src/Platform/Report.h @@ -3,6 +3,7 @@ #include #include #include +#include "Colors.h" using json = nlohmann::json; const int MAXL = 121; diff --git a/src/Platform/Results.cc b/src/Platform/Results.cc index 48c7e9a..0bf4070 100644 --- a/src/Platform/Results.cc +++ b/src/Platform/Results.cc @@ -3,6 +3,7 @@ #include "Results.h" #include "Report.h" #include "BestResult.h" +#include "Colors.h" namespace platform { Result::Result(const string& path, const string& filename) : path(path) @@ -59,25 +60,35 @@ namespace platform { } void Results::show() const { - cout << "Results found: " << files.size() << endl; + cout << Colors::GREEN() << "Results found: " << files.size() << endl; cout << "-------------------" << endl; auto i = 0; cout << " # Date Model Score Name Score Duration Title" << endl; cout << "=== ========== ============ =========== =========== ========= =============================================================" << endl; + bool odd = true; for (const auto& result : files) { - cout << setw(3) << fixed << right << i++ << " "; + auto color = odd ? Colors::BLUE() : Colors::CYAN(); + cout << color << setw(3) << fixed << right << i++ << " "; cout << result.to_string() << endl; if (i == max && max != 0) { break; } + odd = !odd; } } int Results::getIndex(const string& intent) const { - cout << "Choose result to " << intent << ": "; - int index; - cin >> index; - if (index >= 0 && index < files.size()) { + string color; + if (intent == "delete") { + color = Colors::RED(); + } else { + color = Colors::YELLOW(); + } + cout << color << "Choose result to " << intent << " (cancel=-1): "; + string line; + getline(cin, line); + int index = stoi(line); + if (index >= -1 && index < static_cast(files.size())) { return index; } cout << "Invalid index" << endl; @@ -85,7 +96,7 @@ namespace platform { } void Results::report(const int index) const { - cout << "Reporting " << files.at(index).getFilename() << endl; + cout << Colors::YELLOW() << "Reporting " << files.at(index).getFilename() << endl; auto data = files.at(index).load(); Report report(data); report.show(); @@ -97,7 +108,7 @@ namespace platform { bool finished = false; string filename, line, options = "qldhsr"; while (!finished) { - cout << "Choose option (quit='q', list='l', delete='d', hide='h', sort='s', report='r'): "; + cout << Colors::RESET() << "Choose option (quit='q', list='l', delete='d', hide='h', sort='s', report='r'): "; getline(cin, line); if (line.size() == 0) continue; @@ -131,6 +142,7 @@ namespace platform { cout << "Deleting " << filename << endl; remove((path + "/" + filename).c_str()); files.erase(files.begin() + index); + cout << "File: " + filename + " deleted!" << endl; show(); break; case 'h': @@ -161,7 +173,7 @@ namespace platform { } void Results::sortList() { - cout << "Choose sorting field (date='d', score='s', duration='u', model='m'): "; + cout << Colors::YELLOW() << "Choose sorting field (date='d', score='s', duration='u', model='m'): "; string line; char option; getline(cin, line); -- 2.45.2