diff --git a/src/Platform/CMakeLists.txt b/src/Platform/CMakeLists.txt index 6cd1cf0..ae295c9 100644 --- a/src/Platform/CMakeLists.txt +++ b/src/Platform/CMakeLists.txt @@ -5,7 +5,7 @@ include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp) include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include) include_directories(${BayesNet_SOURCE_DIR}/lib/json/include) add_executable(main main.cc Folding.cc platformUtils.cc Experiment.cc Datasets.cc Models.cc ReportConsole.cc ReportBase.cc) -add_executable(manage manage.cc Results.cc ReportConsole.cc ReportBase.cc) +add_executable(manage manage.cc Results.cc ReportConsole.cc ReportExcel.cc ReportBase.cc) add_executable(list list.cc platformUtils Datasets.cc) target_link_libraries(main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}") target_link_libraries(manage "${TORCH_LIBRARIES}" OpenXLSX::OpenXLSX) diff --git a/src/Platform/ReportConsole.cc b/src/Platform/ReportConsole.cc index 52f822d..910b0df 100644 --- a/src/Platform/ReportConsole.cc +++ b/src/Platform/ReportConsole.cc @@ -10,7 +10,7 @@ namespace platform { char do_thousands_sep() const { return '.'; } string do_grouping() const { return "\03"; } }; - string headerLine(const string& text) + string ReportConsole::headerLine(const string& text) { int n = MAXL - text.length() - 3; n = n < 0 ? 0 : n; diff --git a/src/Platform/ReportConsole.h b/src/Platform/ReportConsole.h index d09f5e1..57e0024 100644 --- a/src/Platform/ReportConsole.h +++ b/src/Platform/ReportConsole.h @@ -2,20 +2,18 @@ #define REPORTCONSOLE_H #include #include -#include #include "ReportBase.h" #include "Colors.h" -using json = nlohmann::json; -const int MAXL = 128; namespace platform { using namespace std; + const int MAXL = 128; class ReportConsole : public ReportBase{ public: explicit ReportConsole(json data_) : ReportBase(data_) {}; virtual ~ReportConsole() = default; private: - + string headerLine(const string& text); void header() override; void body() override; void footer() override; diff --git a/src/Platform/ReportExcel.cc b/src/Platform/ReportExcel.cc new file mode 100644 index 0000000..275248a --- /dev/null +++ b/src/Platform/ReportExcel.cc @@ -0,0 +1,85 @@ +#include +#include +#include "ReportExcel.h" +#include "BestResult.h" + + +namespace platform { + struct separated : numpunct { + char do_decimal_point() const { return ','; } + char do_thousands_sep() const { return '.'; } + string do_grouping() const { return "\03"; } + }; + string headerLine(const string& text) + { + int n = MAXLL - text.length() - 3; + n = n < 0 ? 0 : n; + return "* " + text + string(n, ' ') + "*\n"; + } + + void ReportExcel::header() + { + locale mylocale(cout.getloc(), new separated); + locale::global(mylocale); + cout.imbue(mylocale); + stringstream oss; + cout << Colors::MAGENTA() << string(MAXLL, '*') << endl; + cout << headerLine("Report " + data["model"].get() + " ver. " + data["version"].get() + " with " + to_string(data["folds"].get()) + " Folds cross validation and " + to_string(data["seeds"].size()) + " random seeds. " + data["date"].get() + " " + data["time"].get()); + cout << headerLine(data["title"].get()); + cout << headerLine("Random seeds: " + fromVector("seeds") + " Stratified: " + (data["stratified"].get() ? "True" : "False")); + oss << "Execution took " << setprecision(2) << fixed << data["duration"].get() << " seconds, " << data["duration"].get() / 3600 << " hours, on " << data["platform"].get(); + cout << headerLine(oss.str()); + cout << headerLine("Score is " + data["score_name"].get()); + cout << string(MAXLL, '*') << endl; + cout << endl; + } + void ReportExcel::body() + { + cout << Colors::GREEN() << "Dataset Sampl. Feat. Cls Nodes Edges States Score Time Hyperparameters" << endl; + cout << "============================== ====== ===== === ========= ========= ========= =============== ================== ===============" << endl; + json lastResult; + totalScore = 0; + bool odd = true; + for (const auto& r : data["results"]) { + auto color = odd ? Colors::CYAN() : Colors::BLUE(); + cout << color << setw(30) << left << r["dataset"].get() << " "; + cout << setw(6) << right << r["samples"].get() << " "; + cout << setw(5) << right << r["features"].get() << " "; + cout << setw(3) << right << r["classes"].get() << " "; + cout << setw(9) << setprecision(2) << fixed << r["nodes"].get() << " "; + cout << setw(9) << setprecision(2) << fixed << r["leaves"].get() << " "; + cout << setw(9) << setprecision(2) << fixed << r["depth"].get() << " "; + cout << setw(8) << right << setprecision(6) << fixed << r["score"].get() << "±" << setw(6) << setprecision(4) << fixed << r["score_std"].get() << " "; + cout << setw(11) << right << setprecision(6) << fixed << r["time"].get() << "±" << setw(6) << setprecision(4) << fixed << r["time_std"].get() << " "; + try { + cout << r["hyperparameters"].get(); + } + catch (const exception& err) { + cout << r["hyperparameters"]; + } + cout << endl; + lastResult = r; + totalScore += r["score"].get(); + odd = !odd; + } + if (data["results"].size() == 1) { + cout << string(MAXLL, '*') << endl; + cout << headerLine(fVector("Train scores: ", lastResult["scores_train"], 14, 12)); + cout << headerLine(fVector("Test scores: ", lastResult["scores_test"], 14, 12)); + cout << headerLine(fVector("Train times: ", lastResult["times_train"], 10, 3)); + cout << headerLine(fVector("Test times: ", lastResult["times_test"], 10, 3)); + cout << string(MAXLL, '*') << endl; + } + } + void ReportExcel::footer() + { + cout << Colors::MAGENTA() << string(MAXLL, '*') << endl; + auto score = data["score_name"].get(); + if (score == BestResult::scoreName()) { + stringstream oss; + oss << score << " compared to " << BestResult::title() << " .: " << totalScore / BestResult::score(); + cout << headerLine(oss.str()); + } + cout << string(MAXLL, '*') << endl << Colors::RESET(); + } +} \ No newline at end of file diff --git a/src/Platform/ReportExcel.h b/src/Platform/ReportExcel.h index 9cb4313..e19dca9 100644 --- a/src/Platform/ReportExcel.h +++ b/src/Platform/ReportExcel.h @@ -1,8 +1,10 @@ #ifndef REPORTEXCEL_H #define REPORTEXCEL_H #include "ReportBase.h" +#include "Colors.h" namespace platform { using namespace std; + const int MAXLL = 128; class ReportExcel : public ReportBase{ public: explicit ReportExcel(json data_) : ReportBase(data_) {}; diff --git a/src/Platform/Results.cc b/src/Platform/Results.cc index 7c19871..0440bc1 100644 --- a/src/Platform/Results.cc +++ b/src/Platform/Results.cc @@ -2,6 +2,7 @@ #include "platformUtils.h" #include "Results.h" #include "ReportConsole.h" +#include "ReportExcel.h" #include "BestResult.h" #include "Colors.h" namespace platform { @@ -94,21 +95,26 @@ namespace platform { cout << "Invalid index" << endl; return -1; } - void Results::report(const int index) const + void Results::report(const int index, const bool excelReport) const { cout << Colors::YELLOW() << "Reporting " << files.at(index).getFilename() << endl; auto data = files.at(index).load(); - ReportConsole report(data); - report.show(); + if (excelReport) { + ReportExcel report(data); + report.show(); + } else { + ReportConsole report(data); + report.show(); + } } void Results::menu() { char option; int index; bool finished = false; - string filename, line, options = "qldhsr"; + string filename, line, options = "qldhsre"; while (!finished) { - cout << Colors::RESET() << "Choose option (quit='q', list='l', delete='d', hide='h', sort='s', report='r'): "; + cout << Colors::RESET() << "Choose option (quit='q', list='l', delete='d', hide='h', sort='s', report='r', excel='e'): "; getline(cin, line); if (line.size() == 0) continue; @@ -119,12 +125,14 @@ namespace platform { } option = line[0]; } else { - index = stoi(line); - if (index >= 0 && index < files.size()) { - report(index); - } else { - cout << "Invalid option" << endl; + if (all_of(line.begin(), line.end(), ::isdigit)) { + index = stoi(line); + if (index >= 0 && index < files.size()) { + report(index, false); + continue; + } } + cout << "Invalid option" << endl; continue; } switch (option) { @@ -164,7 +172,13 @@ namespace platform { index = getIndex("report"); if (index == -1) break; - report(index); + report(index, false); + break; + case 'e': + index = getIndex("excel"); + if (index == -1) + break; + report(index, true); break; default: cout << "Invalid option" << endl; diff --git a/src/Platform/Results.h b/src/Platform/Results.h index e6b1552..39f3c55 100644 --- a/src/Platform/Results.h +++ b/src/Platform/Results.h @@ -42,7 +42,7 @@ namespace platform { vector files; void load(); // Loads the list of results void show() const; - void report(const int index) const; + void report(const int index, const bool excelReport) const; int getIndex(const string& intent) const; void menu(); void sortList();