From 4370bf51d7e8098cd7e4a8f9c086c29d1b72d905 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Mon, 21 Aug 2023 17:14:23 +0200 Subject: [PATCH 1/8] Refactor Report into ReportBase & ReportConsole --- .gitmodules | 3 +++ lib/openXLSX | 1 + 2 files changed, 4 insertions(+) create mode 160000 lib/openXLSX diff --git a/.gitmodules b/.gitmodules index 2989f8a..626d10f 100644 --- a/.gitmodules +++ b/.gitmodules @@ -10,3 +10,6 @@ [submodule "lib/json"] path = lib/json url = https://github.com/nlohmann/json.git +[submodule "lib/openXLSX"] + path = lib/openXLSX + url = https://github.com/troldal/OpenXLSX.git diff --git a/lib/openXLSX b/lib/openXLSX new file mode 160000 index 0000000..b80da42 --- /dev/null +++ b/lib/openXLSX @@ -0,0 +1 @@ +Subproject commit b80da42d1454f361c29117095ebe1989437db390 -- 2.45.2 From 0f66ac73d0d15540a07e182b6a32d00a79eaa730 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Mon, 21 Aug 2023 17:15:14 +0200 Subject: [PATCH 2/8] Revert "Refactor Report into ReportBase & ReportConsole" This reverts commit 4370bf51d7e8098cd7e4a8f9c086c29d1b72d905. --- .gitmodules | 3 --- lib/openXLSX | 1 - 2 files changed, 4 deletions(-) delete mode 160000 lib/openXLSX diff --git a/.gitmodules b/.gitmodules index 626d10f..2989f8a 100644 --- a/.gitmodules +++ b/.gitmodules @@ -10,6 +10,3 @@ [submodule "lib/json"] path = lib/json url = https://github.com/nlohmann/json.git -[submodule "lib/openXLSX"] - path = lib/openXLSX - url = https://github.com/troldal/OpenXLSX.git diff --git a/lib/openXLSX b/lib/openXLSX deleted file mode 160000 index b80da42..0000000 --- a/lib/openXLSX +++ /dev/null @@ -1 +0,0 @@ -Subproject commit b80da42d1454f361c29117095ebe1989437db390 -- 2.45.2 From 8066701c3cf143774e4f0143c260c6b2e85aa309 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Mon, 21 Aug 2023 17:16:29 +0200 Subject: [PATCH 3/8] Refactor Report class into ReportBase & ReportCons --- CMakeLists.txt | 1 + lib/openXLSX | 1 + src/Platform/CMakeLists.txt | 6 +-- src/Platform/Experiment.cc | 4 +- src/Platform/Report.h | 26 ---------- src/Platform/ReportBase.cc | 38 +++++++++++++++ src/Platform/ReportBase.h | 25 ++++++++++ src/Platform/{Report.cc => ReportConsole.cc} | 50 ++++---------------- src/Platform/ReportConsole.h | 24 ++++++++++ src/Platform/ReportExcel.h | 16 +++++++ src/Platform/Results.cc | 4 +- 11 files changed, 122 insertions(+), 73 deletions(-) create mode 160000 lib/openXLSX delete mode 100644 src/Platform/Report.h create mode 100644 src/Platform/ReportBase.cc create mode 100644 src/Platform/ReportBase.h rename src/Platform/{Report.cc => ReportConsole.cc} (83%) create mode 100644 src/Platform/ReportConsole.h create mode 100644 src/Platform/ReportExcel.h diff --git a/CMakeLists.txt b/CMakeLists.txt index c53a3a2..a187279 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -55,6 +55,7 @@ endif (ENABLE_CLANG_TIDY) add_git_submodule("lib/mdlp") add_git_submodule("lib/argparse") add_git_submodule("lib/json") +add_git_submodule("lib/openXLSX") # Subdirectories # -------------- diff --git a/lib/openXLSX b/lib/openXLSX new file mode 160000 index 0000000..b80da42 --- /dev/null +++ b/lib/openXLSX @@ -0,0 +1 @@ +Subproject commit b80da42d1454f361c29117095ebe1989437db390 diff --git a/src/Platform/CMakeLists.txt b/src/Platform/CMakeLists.txt index 78c6615..6cd1cf0 100644 --- a/src/Platform/CMakeLists.txt +++ b/src/Platform/CMakeLists.txt @@ -4,9 +4,9 @@ include_directories(${BayesNet_SOURCE_DIR}/lib/Files) include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp) include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include) include_directories(${BayesNet_SOURCE_DIR}/lib/json/include) -add_executable(main main.cc Folding.cc platformUtils.cc Experiment.cc Datasets.cc Models.cc Report.cc) -add_executable(manage manage.cc Results.cc Report.cc) +add_executable(main main.cc Folding.cc platformUtils.cc Experiment.cc Datasets.cc Models.cc ReportConsole.cc ReportBase.cc) +add_executable(manage manage.cc Results.cc ReportConsole.cc ReportBase.cc) add_executable(list list.cc platformUtils Datasets.cc) target_link_libraries(main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}") -target_link_libraries(manage "${TORCH_LIBRARIES}") +target_link_libraries(manage "${TORCH_LIBRARIES}" OpenXLSX::OpenXLSX) target_link_libraries(list ArffFiles mdlp "${TORCH_LIBRARIES}") \ No newline at end of file diff --git a/src/Platform/Experiment.cc b/src/Platform/Experiment.cc index 88e3125..480a9de 100644 --- a/src/Platform/Experiment.cc +++ b/src/Platform/Experiment.cc @@ -1,7 +1,7 @@ #include "Experiment.h" #include "Datasets.h" #include "Models.h" -#include "Report.h" +#include "ReportConsole.h" namespace platform { using json = nlohmann::json; @@ -91,7 +91,7 @@ namespace platform { void Experiment::report() { json data = build_json(); - Report report(data); + ReportConsole report(data); report.show(); } diff --git a/src/Platform/Report.h b/src/Platform/Report.h deleted file mode 100644 index 105785f..0000000 --- a/src/Platform/Report.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef REPORT_H -#define REPORT_H -#include -#include -#include -#include "Colors.h" - -using json = nlohmann::json; -const int MAXL = 128; -namespace platform { - using namespace std; - class Report { - public: - explicit Report(json data_) { data = data_; }; - virtual ~Report() = default; - void show(); - private: - void header(); - void body(); - void footer(); - string fromVector(const string& key); - json data; - double totalScore; // Total score of all results in a report - }; -}; -#endif \ No newline at end of file diff --git a/src/Platform/ReportBase.cc b/src/Platform/ReportBase.cc new file mode 100644 index 0000000..f22a89f --- /dev/null +++ b/src/Platform/ReportBase.cc @@ -0,0 +1,38 @@ +#include +#include +#include "ReportBase.h" +#include "BestResult.h" + + +namespace platform { + string ReportBase::fromVector(const string& key) + { + stringstream oss; + string sep = ""; + oss << "["; + for (auto& item : data[key]) { + oss << sep << item.get(); + sep = ", "; + } + oss << "]"; + return oss.str(); + } + string ReportBase::fVector(const string& title, const json& data, const int width, const int precision) + { + stringstream oss; + string sep = ""; + oss << title << "["; + for (const auto& item : data) { + oss << sep << fixed << setw(width) << setprecision(precision) << item.get(); + sep = ", "; + } + oss << "]"; + return oss.str(); + } + void ReportBase::show() + { + header(); + body(); + footer(); + } +} \ No newline at end of file diff --git a/src/Platform/ReportBase.h b/src/Platform/ReportBase.h new file mode 100644 index 0000000..2d8a072 --- /dev/null +++ b/src/Platform/ReportBase.h @@ -0,0 +1,25 @@ +#ifndef REPORTBASE_H +#define REPORTBASE_H +#include +#include +#include + +using json = nlohmann::json; +namespace platform { + using namespace std; + class ReportBase { + public: + explicit ReportBase(json data_) { data = data_; }; + virtual ~ReportBase() = default; + void show(); + protected: + json data; + double totalScore; // Total score of all results in a report + string fromVector(const string& key); + string fVector(const string& title, const json& data, const int width, const int precision); + virtual void header() = 0; + virtual void body() = 0; + virtual void footer() = 0; + }; +}; +#endif \ No newline at end of file diff --git a/src/Platform/Report.cc b/src/Platform/ReportConsole.cc similarity index 83% rename from src/Platform/Report.cc rename to src/Platform/ReportConsole.cc index 5690668..52f822d 100644 --- a/src/Platform/Report.cc +++ b/src/Platform/ReportConsole.cc @@ -1,52 +1,23 @@ #include #include -#include "Report.h" +#include "ReportConsole.h" #include "BestResult.h" namespace platform { + struct separated : numpunct { + char do_decimal_point() const { return ','; } + char do_thousands_sep() const { return '.'; } + string do_grouping() const { return "\03"; } + }; string headerLine(const string& text) { int n = MAXL - text.length() - 3; n = n < 0 ? 0 : n; return "* " + text + string(n, ' ') + "*\n"; } - string Report::fromVector(const string& key) - { - stringstream oss; - string sep = ""; - oss << "["; - for (auto& item : data[key]) { - oss << sep << item.get(); - sep = ", "; - } - oss << "]"; - return oss.str(); - } - string fVector(const string& title, const json& data, const int width, const int precision) - { - stringstream oss; - string sep = ""; - oss << title << "["; - for (const auto& item : data) { - oss << sep << fixed << setw(width) << setprecision(precision) << item.get(); - sep = ", "; - } - oss << "]"; - return oss.str(); - } - void Report::show() - { - header(); - body(); - footer(); - } - struct separated : numpunct { - char do_decimal_point() const { return ','; } - char do_thousands_sep() const { return '.'; } - string do_grouping() const { return "\03"; } - }; - void Report::header() + + void ReportConsole::header() { locale mylocale(cout.getloc(), new separated); locale::global(mylocale); @@ -62,7 +33,7 @@ namespace platform { cout << string(MAXL, '*') << endl; cout << endl; } - void Report::body() + void ReportConsole::body() { cout << Colors::GREEN() << "Dataset Sampl. Feat. Cls Nodes Edges States Score Time Hyperparameters" << endl; cout << "============================== ====== ===== === ========= ========= ========= =============== ================== ===============" << endl; @@ -100,7 +71,7 @@ namespace platform { cout << string(MAXL, '*') << endl; } } - void Report::footer() + void ReportConsole::footer() { cout << Colors::MAGENTA() << string(MAXL, '*') << endl; auto score = data["score_name"].get(); @@ -110,6 +81,5 @@ namespace platform { cout << headerLine(oss.str()); } cout << string(MAXL, '*') << endl << Colors::RESET(); - } } \ No newline at end of file diff --git a/src/Platform/ReportConsole.h b/src/Platform/ReportConsole.h new file mode 100644 index 0000000..d09f5e1 --- /dev/null +++ b/src/Platform/ReportConsole.h @@ -0,0 +1,24 @@ +#ifndef REPORTCONSOLE_H +#define REPORTCONSOLE_H +#include +#include +#include +#include "ReportBase.h" +#include "Colors.h" + +using json = nlohmann::json; +const int MAXL = 128; +namespace platform { + using namespace std; + class ReportConsole : public ReportBase{ + public: + explicit ReportConsole(json data_) : ReportBase(data_) {}; + virtual ~ReportConsole() = default; + private: + + void header() override; + void body() override; + void footer() override; + }; +}; +#endif \ No newline at end of file diff --git a/src/Platform/ReportExcel.h b/src/Platform/ReportExcel.h new file mode 100644 index 0000000..9cb4313 --- /dev/null +++ b/src/Platform/ReportExcel.h @@ -0,0 +1,16 @@ +#ifndef REPORTEXCEL_H +#define REPORTEXCEL_H +#include "ReportBase.h" +namespace platform { + using namespace std; + class ReportExcel : public ReportBase{ + public: + explicit ReportExcel(json data_) : ReportBase(data_) {}; + virtual ~ReportExcel() = default; + private: + void header() override; + void body() override; + void footer() override; + }; +}; +#endif // !REPORTEXCEL_H \ No newline at end of file diff --git a/src/Platform/Results.cc b/src/Platform/Results.cc index 818f51e..7c19871 100644 --- a/src/Platform/Results.cc +++ b/src/Platform/Results.cc @@ -1,7 +1,7 @@ #include #include "platformUtils.h" #include "Results.h" -#include "Report.h" +#include "ReportConsole.h" #include "BestResult.h" #include "Colors.h" namespace platform { @@ -98,7 +98,7 @@ namespace platform { { cout << Colors::YELLOW() << "Reporting " << files.at(index).getFilename() << endl; auto data = files.at(index).load(); - Report report(data); + ReportConsole report(data); report.show(); } void Results::menu() -- 2.45.2 From d2da0ddb88c57d504243b8d6648cd1d909f4172f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Mon, 21 Aug 2023 17:51:49 +0200 Subject: [PATCH 4/8] Create ReportExcel eq to ReportConsole --- src/Platform/CMakeLists.txt | 2 +- src/Platform/ReportConsole.cc | 2 +- src/Platform/ReportConsole.h | 6 +-- src/Platform/ReportExcel.cc | 85 +++++++++++++++++++++++++++++++++++ src/Platform/ReportExcel.h | 2 + src/Platform/Results.cc | 36 ++++++++++----- src/Platform/Results.h | 2 +- 7 files changed, 117 insertions(+), 18 deletions(-) create mode 100644 src/Platform/ReportExcel.cc diff --git a/src/Platform/CMakeLists.txt b/src/Platform/CMakeLists.txt index 6cd1cf0..ae295c9 100644 --- a/src/Platform/CMakeLists.txt +++ b/src/Platform/CMakeLists.txt @@ -5,7 +5,7 @@ include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp) include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include) include_directories(${BayesNet_SOURCE_DIR}/lib/json/include) add_executable(main main.cc Folding.cc platformUtils.cc Experiment.cc Datasets.cc Models.cc ReportConsole.cc ReportBase.cc) -add_executable(manage manage.cc Results.cc ReportConsole.cc ReportBase.cc) +add_executable(manage manage.cc Results.cc ReportConsole.cc ReportExcel.cc ReportBase.cc) add_executable(list list.cc platformUtils Datasets.cc) target_link_libraries(main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}") target_link_libraries(manage "${TORCH_LIBRARIES}" OpenXLSX::OpenXLSX) diff --git a/src/Platform/ReportConsole.cc b/src/Platform/ReportConsole.cc index 52f822d..910b0df 100644 --- a/src/Platform/ReportConsole.cc +++ b/src/Platform/ReportConsole.cc @@ -10,7 +10,7 @@ namespace platform { char do_thousands_sep() const { return '.'; } string do_grouping() const { return "\03"; } }; - string headerLine(const string& text) + string ReportConsole::headerLine(const string& text) { int n = MAXL - text.length() - 3; n = n < 0 ? 0 : n; diff --git a/src/Platform/ReportConsole.h b/src/Platform/ReportConsole.h index d09f5e1..57e0024 100644 --- a/src/Platform/ReportConsole.h +++ b/src/Platform/ReportConsole.h @@ -2,20 +2,18 @@ #define REPORTCONSOLE_H #include #include -#include #include "ReportBase.h" #include "Colors.h" -using json = nlohmann::json; -const int MAXL = 128; namespace platform { using namespace std; + const int MAXL = 128; class ReportConsole : public ReportBase{ public: explicit ReportConsole(json data_) : ReportBase(data_) {}; virtual ~ReportConsole() = default; private: - + string headerLine(const string& text); void header() override; void body() override; void footer() override; diff --git a/src/Platform/ReportExcel.cc b/src/Platform/ReportExcel.cc new file mode 100644 index 0000000..275248a --- /dev/null +++ b/src/Platform/ReportExcel.cc @@ -0,0 +1,85 @@ +#include +#include +#include "ReportExcel.h" +#include "BestResult.h" + + +namespace platform { + struct separated : numpunct { + char do_decimal_point() const { return ','; } + char do_thousands_sep() const { return '.'; } + string do_grouping() const { return "\03"; } + }; + string headerLine(const string& text) + { + int n = MAXLL - text.length() - 3; + n = n < 0 ? 0 : n; + return "* " + text + string(n, ' ') + "*\n"; + } + + void ReportExcel::header() + { + locale mylocale(cout.getloc(), new separated); + locale::global(mylocale); + cout.imbue(mylocale); + stringstream oss; + cout << Colors::MAGENTA() << string(MAXLL, '*') << endl; + cout << headerLine("Report " + data["model"].get() + " ver. " + data["version"].get() + " with " + to_string(data["folds"].get()) + " Folds cross validation and " + to_string(data["seeds"].size()) + " random seeds. " + data["date"].get() + " " + data["time"].get()); + cout << headerLine(data["title"].get()); + cout << headerLine("Random seeds: " + fromVector("seeds") + " Stratified: " + (data["stratified"].get() ? "True" : "False")); + oss << "Execution took " << setprecision(2) << fixed << data["duration"].get() << " seconds, " << data["duration"].get() / 3600 << " hours, on " << data["platform"].get(); + cout << headerLine(oss.str()); + cout << headerLine("Score is " + data["score_name"].get()); + cout << string(MAXLL, '*') << endl; + cout << endl; + } + void ReportExcel::body() + { + cout << Colors::GREEN() << "Dataset Sampl. Feat. Cls Nodes Edges States Score Time Hyperparameters" << endl; + cout << "============================== ====== ===== === ========= ========= ========= =============== ================== ===============" << endl; + json lastResult; + totalScore = 0; + bool odd = true; + for (const auto& r : data["results"]) { + auto color = odd ? Colors::CYAN() : Colors::BLUE(); + cout << color << setw(30) << left << r["dataset"].get() << " "; + cout << setw(6) << right << r["samples"].get() << " "; + cout << setw(5) << right << r["features"].get() << " "; + cout << setw(3) << right << r["classes"].get() << " "; + cout << setw(9) << setprecision(2) << fixed << r["nodes"].get() << " "; + cout << setw(9) << setprecision(2) << fixed << r["leaves"].get() << " "; + cout << setw(9) << setprecision(2) << fixed << r["depth"].get() << " "; + cout << setw(8) << right << setprecision(6) << fixed << r["score"].get() << "±" << setw(6) << setprecision(4) << fixed << r["score_std"].get() << " "; + cout << setw(11) << right << setprecision(6) << fixed << r["time"].get() << "±" << setw(6) << setprecision(4) << fixed << r["time_std"].get() << " "; + try { + cout << r["hyperparameters"].get(); + } + catch (const exception& err) { + cout << r["hyperparameters"]; + } + cout << endl; + lastResult = r; + totalScore += r["score"].get(); + odd = !odd; + } + if (data["results"].size() == 1) { + cout << string(MAXLL, '*') << endl; + cout << headerLine(fVector("Train scores: ", lastResult["scores_train"], 14, 12)); + cout << headerLine(fVector("Test scores: ", lastResult["scores_test"], 14, 12)); + cout << headerLine(fVector("Train times: ", lastResult["times_train"], 10, 3)); + cout << headerLine(fVector("Test times: ", lastResult["times_test"], 10, 3)); + cout << string(MAXLL, '*') << endl; + } + } + void ReportExcel::footer() + { + cout << Colors::MAGENTA() << string(MAXLL, '*') << endl; + auto score = data["score_name"].get(); + if (score == BestResult::scoreName()) { + stringstream oss; + oss << score << " compared to " << BestResult::title() << " .: " << totalScore / BestResult::score(); + cout << headerLine(oss.str()); + } + cout << string(MAXLL, '*') << endl << Colors::RESET(); + } +} \ No newline at end of file diff --git a/src/Platform/ReportExcel.h b/src/Platform/ReportExcel.h index 9cb4313..e19dca9 100644 --- a/src/Platform/ReportExcel.h +++ b/src/Platform/ReportExcel.h @@ -1,8 +1,10 @@ #ifndef REPORTEXCEL_H #define REPORTEXCEL_H #include "ReportBase.h" +#include "Colors.h" namespace platform { using namespace std; + const int MAXLL = 128; class ReportExcel : public ReportBase{ public: explicit ReportExcel(json data_) : ReportBase(data_) {}; diff --git a/src/Platform/Results.cc b/src/Platform/Results.cc index 7c19871..0440bc1 100644 --- a/src/Platform/Results.cc +++ b/src/Platform/Results.cc @@ -2,6 +2,7 @@ #include "platformUtils.h" #include "Results.h" #include "ReportConsole.h" +#include "ReportExcel.h" #include "BestResult.h" #include "Colors.h" namespace platform { @@ -94,21 +95,26 @@ namespace platform { cout << "Invalid index" << endl; return -1; } - void Results::report(const int index) const + void Results::report(const int index, const bool excelReport) const { cout << Colors::YELLOW() << "Reporting " << files.at(index).getFilename() << endl; auto data = files.at(index).load(); - ReportConsole report(data); - report.show(); + if (excelReport) { + ReportExcel report(data); + report.show(); + } else { + ReportConsole report(data); + report.show(); + } } void Results::menu() { char option; int index; bool finished = false; - string filename, line, options = "qldhsr"; + string filename, line, options = "qldhsre"; while (!finished) { - cout << Colors::RESET() << "Choose option (quit='q', list='l', delete='d', hide='h', sort='s', report='r'): "; + cout << Colors::RESET() << "Choose option (quit='q', list='l', delete='d', hide='h', sort='s', report='r', excel='e'): "; getline(cin, line); if (line.size() == 0) continue; @@ -119,12 +125,14 @@ namespace platform { } option = line[0]; } else { - index = stoi(line); - if (index >= 0 && index < files.size()) { - report(index); - } else { - cout << "Invalid option" << endl; + if (all_of(line.begin(), line.end(), ::isdigit)) { + index = stoi(line); + if (index >= 0 && index < files.size()) { + report(index, false); + continue; + } } + cout << "Invalid option" << endl; continue; } switch (option) { @@ -164,7 +172,13 @@ namespace platform { index = getIndex("report"); if (index == -1) break; - report(index); + report(index, false); + break; + case 'e': + index = getIndex("excel"); + if (index == -1) + break; + report(index, true); break; default: cout << "Invalid option" << endl; diff --git a/src/Platform/Results.h b/src/Platform/Results.h index e6b1552..39f3c55 100644 --- a/src/Platform/Results.h +++ b/src/Platform/Results.h @@ -42,7 +42,7 @@ namespace platform { vector files; void load(); // Loads the list of results void show() const; - void report(const int index) const; + void report(const int index, const bool excelReport) const; int getIndex(const string& intent) const; void menu(); void sortList(); -- 2.45.2 From c59dd30e53ed6b801add8d99466f88034ed8f7c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Tue, 22 Aug 2023 11:55:15 +0200 Subject: [PATCH 5/8] Complete Excel Report with data --- src/Platform/Paths.h | 1 + src/Platform/ReportBase.cc | 1 - src/Platform/ReportBase.h | 2 - src/Platform/ReportConsole.cc | 7 ++- src/Platform/ReportConsole.h | 2 +- src/Platform/ReportExcel.cc | 112 +++++++++++++++++++++------------- src/Platform/ReportExcel.h | 13 +++- 7 files changed, 85 insertions(+), 53 deletions(-) diff --git a/src/Platform/Paths.h b/src/Platform/Paths.h index fdda25a..b19b09f 100644 --- a/src/Platform/Paths.h +++ b/src/Platform/Paths.h @@ -6,6 +6,7 @@ namespace platform { public: static std::string datasets() { return "datasets/"; } static std::string results() { return "results/"; } + static std::string excel() { return "excel/"; } }; } #endif \ No newline at end of file diff --git a/src/Platform/ReportBase.cc b/src/Platform/ReportBase.cc index f22a89f..24125f8 100644 --- a/src/Platform/ReportBase.cc +++ b/src/Platform/ReportBase.cc @@ -33,6 +33,5 @@ namespace platform { { header(); body(); - footer(); } } \ No newline at end of file diff --git a/src/Platform/ReportBase.h b/src/Platform/ReportBase.h index 2d8a072..2acbbc7 100644 --- a/src/Platform/ReportBase.h +++ b/src/Platform/ReportBase.h @@ -14,12 +14,10 @@ namespace platform { void show(); protected: json data; - double totalScore; // Total score of all results in a report string fromVector(const string& key); string fVector(const string& title, const json& data, const int width, const int precision); virtual void header() = 0; virtual void body() = 0; - virtual void footer() = 0; }; }; #endif \ No newline at end of file diff --git a/src/Platform/ReportConsole.cc b/src/Platform/ReportConsole.cc index 910b0df..2e3ed0c 100644 --- a/src/Platform/ReportConsole.cc +++ b/src/Platform/ReportConsole.cc @@ -10,6 +10,7 @@ namespace platform { char do_thousands_sep() const { return '.'; } string do_grouping() const { return "\03"; } }; + string ReportConsole::headerLine(const string& text) { int n = MAXL - text.length() - 3; @@ -38,7 +39,7 @@ namespace platform { cout << Colors::GREEN() << "Dataset Sampl. Feat. Cls Nodes Edges States Score Time Hyperparameters" << endl; cout << "============================== ====== ===== === ========= ========= ========= =============== ================== ===============" << endl; json lastResult; - totalScore = 0; + double totalScore = 0.0; bool odd = true; for (const auto& r : data["results"]) { auto color = odd ? Colors::CYAN() : Colors::BLUE(); @@ -69,9 +70,11 @@ namespace platform { cout << headerLine(fVector("Train times: ", lastResult["times_train"], 10, 3)); cout << headerLine(fVector("Test times: ", lastResult["times_test"], 10, 3)); cout << string(MAXL, '*') << endl; + } else { + footer(totalScore); } } - void ReportConsole::footer() + void ReportConsole::footer(double totalScore) { cout << Colors::MAGENTA() << string(MAXL, '*') << endl; auto score = data["score_name"].get(); diff --git a/src/Platform/ReportConsole.h b/src/Platform/ReportConsole.h index 57e0024..5c795b7 100644 --- a/src/Platform/ReportConsole.h +++ b/src/Platform/ReportConsole.h @@ -16,7 +16,7 @@ namespace platform { string headerLine(const string& text); void header() override; void body() override; - void footer() override; + void footer(double totalScore); }; }; #endif \ No newline at end of file diff --git a/src/Platform/ReportExcel.cc b/src/Platform/ReportExcel.cc index 275248a..bb51dfb 100644 --- a/src/Platform/ReportExcel.cc +++ b/src/Platform/ReportExcel.cc @@ -7,79 +7,103 @@ namespace platform { struct separated : numpunct { char do_decimal_point() const { return ','; } + char do_thousands_sep() const { return '.'; } + string do_grouping() const { return "\03"; } }; - string headerLine(const string& text) + + void ReportExcel::createFile() { - int n = MAXLL - text.length() - 3; - n = n < 0 ? 0 : n; - return "* " + text + string(n, ' ') + "*\n"; + doc.create(Paths::excel() + "some_results.xlsx"); + wks = doc.workbook().worksheet("Sheet1"); + wks.setName(data["model"].get()); } - + + void ReportExcel::closeFile() + { + doc.save(); + doc.close(); + } + void ReportExcel::header() { locale mylocale(cout.getloc(), new separated); locale::global(mylocale); cout.imbue(mylocale); stringstream oss; - cout << Colors::MAGENTA() << string(MAXLL, '*') << endl; - cout << headerLine("Report " + data["model"].get() + " ver. " + data["version"].get() + " with " + to_string(data["folds"].get()) + " Folds cross validation and " + to_string(data["seeds"].size()) + " random seeds. " + data["date"].get() + " " + data["time"].get()); - cout << headerLine(data["title"].get()); - cout << headerLine("Random seeds: " + fromVector("seeds") + " Stratified: " + (data["stratified"].get() ? "True" : "False")); - oss << "Execution took " << setprecision(2) << fixed << data["duration"].get() << " seconds, " << data["duration"].get() / 3600 << " hours, on " << data["platform"].get(); - cout << headerLine(oss.str()); - cout << headerLine("Score is " + data["score_name"].get()); - cout << string(MAXLL, '*') << endl; - cout << endl; + wks.cell("A1").value().set( + "Report " + data["model"].get() + " ver. " + data["version"].get() + " with " + + to_string(data["folds"].get()) + " Folds cross validation and " + to_string(data["seeds"].size()) + + " random seeds. " + data["date"].get() + " " + data["time"].get()); + wks.cell("A2").value() = data["title"].get(); + wks.cell("A3").value() = "Random seeds: " + fromVector("seeds") + " Stratified: " + + (data["stratified"].get() ? "True" : "False"); + oss << "Execution took " << setprecision(2) << fixed << data["duration"].get() << " seconds, " + << data["duration"].get() / 3600 << " hours, on " << data["platform"].get(); + wks.cell("A4").value() = oss.str(); + wks.cell("A5").value() = "Score is " + data["score_name"].get(); } + void ReportExcel::body() { - cout << Colors::GREEN() << "Dataset Sampl. Feat. Cls Nodes Edges States Score Time Hyperparameters" << endl; - cout << "============================== ====== ===== === ========= ========= ========= =============== ================== ===============" << endl; + auto header = vector( + { "Dataset", "Samples", "Features", "Classes", "Nodes", "Edges", "States", "Score", "Score Std.", "Time", + "Time Std.", "Hyperparameters" }); + int col = 1; + for (const auto& item : header) { + wks.cell(8, col++).value() = item; + } + int row = 9; + col = 1; json lastResult; - totalScore = 0; - bool odd = true; + double totalScore = 0.0; + string hyperparameters; for (const auto& r : data["results"]) { - auto color = odd ? Colors::CYAN() : Colors::BLUE(); - cout << color << setw(30) << left << r["dataset"].get() << " "; - cout << setw(6) << right << r["samples"].get() << " "; - cout << setw(5) << right << r["features"].get() << " "; - cout << setw(3) << right << r["classes"].get() << " "; - cout << setw(9) << setprecision(2) << fixed << r["nodes"].get() << " "; - cout << setw(9) << setprecision(2) << fixed << r["leaves"].get() << " "; - cout << setw(9) << setprecision(2) << fixed << r["depth"].get() << " "; - cout << setw(8) << right << setprecision(6) << fixed << r["score"].get() << "±" << setw(6) << setprecision(4) << fixed << r["score_std"].get() << " "; - cout << setw(11) << right << setprecision(6) << fixed << r["time"].get() << "±" << setw(6) << setprecision(4) << fixed << r["time_std"].get() << " "; + wks.cell(row, col).value() = r["dataset"].get(); + wks.cell(row, col + 1).value() = r["samples"].get(); + wks.cell(row, col + 2).value() = r["features"].get(); + wks.cell(row, col + 3).value() = r["classes"].get(); + wks.cell(row, col + 4).value() = r["nodes"].get(); + wks.cell(row, col + 5).value() = r["leaves"].get(); + wks.cell(row, col + 6).value() = r["depth"].get(); + wks.cell(row, col + 7).value() = r["score"].get(); + wks.cell(row, col + 8).value() = r["score_std"].get(); + wks.cell(row, col + 9).value() = r["time"].get(); + wks.cell(row, col + 10).value() = r["time_std"].get(); try { - cout << r["hyperparameters"].get(); + hyperparameters = r["hyperparameters"].get(); } catch (const exception& err) { - cout << r["hyperparameters"]; + stringstream oss; + oss << r["hyperparameters"]; + hyperparameters = oss.str(); } - cout << endl; + wks.cell(row, col + 11).value() = hyperparameters; lastResult = r; totalScore += r["score"].get(); - odd = !odd; + row++; } if (data["results"].size() == 1) { - cout << string(MAXLL, '*') << endl; - cout << headerLine(fVector("Train scores: ", lastResult["scores_train"], 14, 12)); - cout << headerLine(fVector("Test scores: ", lastResult["scores_test"], 14, 12)); - cout << headerLine(fVector("Train times: ", lastResult["times_train"], 10, 3)); - cout << headerLine(fVector("Test times: ", lastResult["times_test"], 10, 3)); - cout << string(MAXLL, '*') << endl; + for (const string& group : { "scores_train", "scores_test", "times_train", "times_test" }) { + row++; + col = 1; + wks.cell(row, col).value() = group; + for (double item : lastResult[group]) { + wks.cell(row, ++col).value() = item; + } + } + } else { + footer(totalScore, row); } } - void ReportExcel::footer() + + void ReportExcel::footer(double totalScore, int row) { - cout << Colors::MAGENTA() << string(MAXLL, '*') << endl; auto score = data["score_name"].get(); if (score == BestResult::scoreName()) { - stringstream oss; - oss << score << " compared to " << BestResult::title() << " .: " << totalScore / BestResult::score(); - cout << headerLine(oss.str()); + wks.cell(row + 2, 1).value() = score + " compared to " + BestResult::title() + " .: "; + wks.cell(row + 2, 5).value() = totalScore / BestResult::score(); } - cout << string(MAXLL, '*') << endl << Colors::RESET(); } } \ No newline at end of file diff --git a/src/Platform/ReportExcel.h b/src/Platform/ReportExcel.h index e19dca9..3700681 100644 --- a/src/Platform/ReportExcel.h +++ b/src/Platform/ReportExcel.h @@ -1,18 +1,25 @@ #ifndef REPORTEXCEL_H #define REPORTEXCEL_H +#include #include "ReportBase.h" +#include "Paths.h" #include "Colors.h" namespace platform { using namespace std; + using namespace OpenXLSX; const int MAXLL = 128; class ReportExcel : public ReportBase{ public: - explicit ReportExcel(json data_) : ReportBase(data_) {}; - virtual ~ReportExcel() = default; + explicit ReportExcel(json data_) : ReportBase(data_) {createFile();}; + virtual ~ReportExcel() {closeFile();}; private: + void createFile(); + void closeFile(); + XLDocument doc; + XLWorksheet wks; void header() override; void body() override; - void footer() override; + void footer(double totalScore, int row); }; }; #endif // !REPORTEXCEL_H \ No newline at end of file -- 2.45.2 From 35432b62945396626f38575391a83ddc387ea33a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Tue, 22 Aug 2023 12:30:27 +0200 Subject: [PATCH 6/8] Fix time std was not saved in experiment --- src/Platform/Experiment.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Platform/Experiment.cc b/src/Platform/Experiment.cc index 480a9de..09674ab 100644 --- a/src/Platform/Experiment.cc +++ b/src/Platform/Experiment.cc @@ -181,11 +181,11 @@ namespace platform { item++; } cout << "end. " << flush; - delete fold; } result.setScoreTest(torch::mean(accuracy_test).item()).setScoreTrain(torch::mean(accuracy_train).item()); result.setScoreTestStd(torch::std(accuracy_test).item()).setScoreTrainStd(torch::std(accuracy_train).item()); result.setTrainTime(torch::mean(train_time).item()).setTestTime(torch::mean(test_time).item()); + result.setTestTimeStd(torch::std(test_time).item()).setTrainTimeStd(torch::std(train_time).item()); result.setNodes(torch::mean(nodes).item()).setLeaves(torch::mean(edges).item()).setDepth(torch::mean(num_states).item()); result.setDataset(fileName); addResult(result); -- 2.45.2 From 1c1385b7685e012a40323d506f09a21eb079f64e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Tue, 22 Aug 2023 21:55:17 +0200 Subject: [PATCH 7/8] Fix maxModels mistake in BoostAODE if !repeatSp Throw exception if wrong hyperparmeter is supplied --- src/BayesNet/BoostAODE.cc | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/BayesNet/BoostAODE.cc b/src/BayesNet/BoostAODE.cc index 80fd20c..fa6dabb 100644 --- a/src/BayesNet/BoostAODE.cc +++ b/src/BayesNet/BoostAODE.cc @@ -10,6 +10,13 @@ namespace bayesnet { } void BoostAODE::setHyperparameters(nlohmann::json& hyperparameters) { + // Check if hyperparameters are valid + auto validKeys = { "repeatSparent", "maxModels", "ascending" }; + for (const auto& item : hyperparameters.items()) { + if (find(validKeys.begin(), validKeys.end(), item.key()) == validKeys.end()) { + throw invalid_argument("Hyperparameter " + item.key() + " is not valid"); + } + } if (hyperparameters.contains("repeatSparent")) { repeatSparent = hyperparameters["repeatSparent"]; } @@ -74,7 +81,7 @@ namespace bayesnet { // Step 3.4: Store classifier and its accuracy to weigh its future vote models.push_back(std::move(model)); significanceModels.push_back(significance); - exitCondition = n_models == maxModels; + exitCondition = n_models == maxModels && repeatSparent; } if (featuresUsed.size() != features.size()) { cout << "Warning: BoostAODE did not use all the features" << endl; -- 2.45.2 From 97ca8ac0849cd781497673e2ed66b0ea5936b224 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Tue, 22 Aug 2023 22:12:20 +0200 Subject: [PATCH 8/8] Move check valid hyperparameters to Classifier --- src/BayesNet/BoostAODE.cc | 8 ++------ src/BayesNet/Classifier.cc | 8 ++++++++ src/BayesNet/Classifier.h | 1 + 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/BayesNet/BoostAODE.cc b/src/BayesNet/BoostAODE.cc index fa6dabb..a7d5e5c 100644 --- a/src/BayesNet/BoostAODE.cc +++ b/src/BayesNet/BoostAODE.cc @@ -11,12 +11,8 @@ namespace bayesnet { void BoostAODE::setHyperparameters(nlohmann::json& hyperparameters) { // Check if hyperparameters are valid - auto validKeys = { "repeatSparent", "maxModels", "ascending" }; - for (const auto& item : hyperparameters.items()) { - if (find(validKeys.begin(), validKeys.end(), item.key()) == validKeys.end()) { - throw invalid_argument("Hyperparameter " + item.key() + " is not valid"); - } - } + const vector validKeys = { "repeatSparent", "maxModels", "ascending" }; + checkHyperparameters(validKeys, hyperparameters); if (hyperparameters.contains("repeatSparent")) { repeatSparent = hyperparameters["repeatSparent"]; } diff --git a/src/BayesNet/Classifier.cc b/src/BayesNet/Classifier.cc index ff25657..db4a63f 100644 --- a/src/BayesNet/Classifier.cc +++ b/src/BayesNet/Classifier.cc @@ -152,4 +152,12 @@ namespace bayesnet { { model.dump_cpt(); } + void Classifier::checkHyperparameters(const vector& validKeys, nlohmann::json& hyperparameters) + { + for (const auto& item : hyperparameters.items()) { + if (find(validKeys.begin(), validKeys.end(), item.key()) == validKeys.end()) { + throw invalid_argument("Hyperparameter " + item.key() + " is not valid"); + } + } + } } \ No newline at end of file diff --git a/src/BayesNet/Classifier.h b/src/BayesNet/Classifier.h index 0c2940b..d27e486 100644 --- a/src/BayesNet/Classifier.h +++ b/src/BayesNet/Classifier.h @@ -24,6 +24,7 @@ namespace bayesnet { void checkFitParameters(); virtual void buildModel(const torch::Tensor& weights) = 0; void trainModel(const torch::Tensor& weights) override; + void checkHyperparameters(const vector& validKeys, nlohmann::json& hyperparameters); public: Classifier(Network model); virtual ~Classifier() = default; -- 2.45.2