diff --git a/.gitignore b/.gitignore index 8855507..dea436c 100644 --- a/.gitignore +++ b/.gitignore @@ -36,3 +36,4 @@ build/ cmake-build*/** .idea puml/** +.vscode/settings.json diff --git a/.gitmodules b/.gitmodules index 626d10f..dbb94fc 100644 --- a/.gitmodules +++ b/.gitmodules @@ -10,6 +10,6 @@ [submodule "lib/json"] path = lib/json url = https://github.com/nlohmann/json.git -[submodule "lib/openXLSX"] - path = lib/openXLSX - url = https://github.com/troldal/OpenXLSX.git +[submodule "lib/libxlsxwriter"] + path = lib/libxlsxwriter + url = https://github.com/jmcnamara/libxlsxwriter.git diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index d7af13f..0000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,109 +0,0 @@ -{ - "files.associations": { - "*.rmd": "markdown", - "*.py": "python", - "vector": "cpp", - "__bit_reference": "cpp", - "__bits": "cpp", - "__config": "cpp", - "__debug": "cpp", - "__errc": "cpp", - "__hash_table": "cpp", - "__locale": "cpp", - "__mutex_base": "cpp", - "__node_handle": "cpp", - "__nullptr": "cpp", - "__split_buffer": "cpp", - "__string": "cpp", - "__threading_support": "cpp", - "__tuple": "cpp", - "array": "cpp", - "atomic": "cpp", - "bitset": "cpp", - "cctype": "cpp", - "chrono": "cpp", - "clocale": "cpp", - "cmath": "cpp", - "compare": "cpp", - "complex": "cpp", - "concepts": "cpp", - "cstdarg": "cpp", - "cstddef": "cpp", - "cstdint": "cpp", - "cstdio": "cpp", - "cstdlib": "cpp", - "cstring": "cpp", - "ctime": "cpp", - "cwchar": "cpp", - "cwctype": "cpp", - "exception": "cpp", - "initializer_list": "cpp", - "ios": "cpp", - "iosfwd": "cpp", - "istream": "cpp", - "limits": "cpp", - "locale": "cpp", - "memory": "cpp", - "mutex": "cpp", - "new": "cpp", - "optional": "cpp", - "ostream": "cpp", - "ratio": "cpp", - "sstream": "cpp", - "stdexcept": "cpp", - "streambuf": "cpp", - "string": "cpp", - "string_view": "cpp", - "system_error": "cpp", - "tuple": "cpp", - "type_traits": "cpp", - "typeinfo": "cpp", - "unordered_map": "cpp", - "variant": "cpp", - "algorithm": "cpp", - "iostream": "cpp", - "iomanip": "cpp", - "numeric": "cpp", - "set": "cpp", - "__tree": "cpp", - "deque": "cpp", - "list": "cpp", - "map": "cpp", - "unordered_set": "cpp", - "any": "cpp", - "condition_variable": "cpp", - "forward_list": "cpp", - "fstream": "cpp", - "stack": "cpp", - "thread": "cpp", - "__memory": "cpp", - "filesystem": "cpp", - "*.toml": "toml", - "utility": "cpp", - "__verbose_abort": "cpp", - "bit": "cpp", - "random": "cpp", - "*.tcc": "cpp", - "functional": "cpp", - "iterator": "cpp", - "memory_resource": "cpp", - "format": "cpp", - "valarray": "cpp", - "regex": "cpp", - "span": "cpp", - "cfenv": "cpp", - "cinttypes": "cpp", - "csetjmp": "cpp", - "future": "cpp", - "queue": "cpp", - "typeindex": "cpp", - "shared_mutex": "cpp", - "*.ipp": "cpp", - "cassert": "cpp", - "charconv": "cpp", - "source_location": "cpp", - "ranges": "cpp" - }, - "cmake.configureOnOpen": false, - "C_Cpp.default.configurationProvider": "ms-vscode.cmake-tools" -} \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index be56d34..83feb88 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -54,7 +54,6 @@ endif (ENABLE_CLANG_TIDY) add_git_submodule("lib/mdlp") add_git_submodule("lib/argparse") add_git_submodule("lib/json") -add_git_submodule("lib/openXLSX") # Subdirectories # -------------- diff --git a/README.md b/README.md index c95d141..266bb8a 100644 --- a/README.md +++ b/README.md @@ -2,4 +2,36 @@ Bayesian Network Classifier with libtorch from scratch +## 0. Setup + +### libxlswriter + +Before compiling BayesNet. + +```bash +cd lib/libxlsxwriter +make +sudo make install +``` + +It has to be installed in /usr/local/lib otherwise CMakeLists.txt has to be modified accordingly + +Environment variable has to be set: + +```bash + export LD_LIBRARY_PATH=/usr/local/lib + ``` + +### Release + +```bash +make release +``` + +### Debug & Tests + +```bash +make debug +``` + ## 1. Introduction diff --git a/lib/catch2 b/lib/catch2 index 4acc518..9c541ca 160000 --- a/lib/catch2 +++ b/lib/catch2 @@ -1 +1 @@ -Subproject commit 4acc51828f7f93f3b2058a63f54d112af4034503 +Subproject commit 9c541ca72e7857dec71d8a41b97e42c2f1c92602 diff --git a/lib/libxlsxwriter b/lib/libxlsxwriter new file mode 160000 index 0000000..44e72c5 --- /dev/null +++ b/lib/libxlsxwriter @@ -0,0 +1 @@ +Subproject commit 44e72c5862f9d549453a4ff6e8ceab0da19705e5 diff --git a/lib/openXLSX b/lib/openXLSX deleted file mode 160000 index b80da42..0000000 --- a/lib/openXLSX +++ /dev/null @@ -1 +0,0 @@ -Subproject commit b80da42d1454f361c29117095ebe1989437db390 diff --git a/src/Platform/CMakeLists.txt b/src/Platform/CMakeLists.txt index b5dc5b9..2a506b8 100644 --- a/src/Platform/CMakeLists.txt +++ b/src/Platform/CMakeLists.txt @@ -5,12 +5,12 @@ include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp) include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include) include_directories(${BayesNet_SOURCE_DIR}/lib/json/include) add_executable(main main.cc Folding.cc platformUtils.cc Experiment.cc Datasets.cc Models.cc ReportConsole.cc ReportBase.cc) -add_executable(manage manage.cc Results.cc ReportConsole.cc ReportExcel.cc ReportBase.cc) +add_executable(manage manage.cc Results.cc ReportConsole.cc ReportExcel.cc ReportBase.cc Datasets.cc platformUtils.cc) add_executable(list list.cc platformUtils Datasets.cc) target_link_libraries(main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}") if (${CMAKE_HOST_SYSTEM_NAME} MATCHES "Linux") - target_link_libraries(manage "${TORCH_LIBRARIES}" OpenXLSX::OpenXLSX stdc++fs) + target_link_libraries(manage "${TORCH_LIBRARIES}" libxlsxwriter.so ArffFiles mdlp stdc++fs) else() - target_link_libraries(manage "${TORCH_LIBRARIES}" OpenXLSX::OpenXLSX) + target_link_libraries(manage "${TORCH_LIBRARIES}" libxlsxwriter.so ArffFiles mdlp) endif() target_link_libraries(list ArffFiles mdlp "${TORCH_LIBRARIES}") \ No newline at end of file diff --git a/src/Platform/ReportBase.cc b/src/Platform/ReportBase.cc index 24125f8..6a5b885 100644 --- a/src/Platform/ReportBase.cc +++ b/src/Platform/ReportBase.cc @@ -1,10 +1,22 @@ #include #include +#include "Datasets.h" #include "ReportBase.h" #include "BestResult.h" namespace platform { + ReportBase::ReportBase(json data_, bool compare) : data(data_), compare(compare), margin(0.1) + { + stringstream oss; + oss << "Better than ZeroR + " << setprecision(1) << fixed << margin * 100 << "%"; + meaning = { + {Symbols::equal_best, "Equal to best"}, + {Symbols::better_best, "Better than best"}, + {Symbols::cross, "Less than or equal to ZeroR"}, + {Symbols::upward_arrow, oss.str()} + }; + } string ReportBase::fromVector(const string& key) { stringstream oss; @@ -34,4 +46,62 @@ namespace platform { header(); body(); } + string ReportBase::compareResult(const string& dataset, double result) + { + string status = " "; + if (compare) { + double best = bestResult(dataset, data["model"].get()); + if (result == best) { + status = Symbols::equal_best; + } else if (result > best) { + status = Symbols::better_best; + } + } else { + if (data["score_name"].get() == "accuracy") { + auto dt = Datasets(Paths::datasets(), false); + dt.loadDataset(dataset); + auto numClasses = dt.getNClasses(dataset); + if (numClasses == 2) { + vector distribution = dt.getClassesCounts(dataset); + double nSamples = dt.getNSamples(dataset); + vector::iterator maxValue = max_element(distribution.begin(), distribution.end()); + double mark = *maxValue / nSamples * (1 + margin); + if (mark > 1) { + mark = 0.9995; + } + status = result < mark ? Symbols::cross : result > mark ? Symbols::upward_arrow : "="; + } + } + } + if (status != " ") { + auto item = summary.find(status); + if (item != summary.end()) { + summary[status]++; + } else { + summary[status] = 1; + } + } + return status; + } + double ReportBase::bestResult(const string& dataset, const string& model) + { + double value = 0.0; + if (bestResults.size() == 0) { + // try to load the best results + string score = data["score_name"]; + replace(score.begin(), score.end(), '_', '-'); + string fileName = "best_results_" + score + "_" + model + ".json"; + ifstream resultData(Paths::results() + "/" + fileName); + if (resultData.is_open()) { + bestResults = json::parse(resultData); + } + } + try { + value = bestResults.at(dataset).at(0); + } + catch (exception) { + value = 1.0; + } + return value; + } } \ No newline at end of file diff --git a/src/Platform/ReportBase.h b/src/Platform/ReportBase.h index 2acbbc7..7695102 100644 --- a/src/Platform/ReportBase.h +++ b/src/Platform/ReportBase.h @@ -2,14 +2,26 @@ #define REPORTBASE_H #include #include +#include "Paths.h" #include using json = nlohmann::json; namespace platform { using namespace std; + class Symbols { + public: + inline static const string check_mark{ "\u2714" }; + inline static const string exclamation{ "\u2757" }; + inline static const string black_star{ "\u2605" }; + inline static const string cross{ "\u2717" }; + inline static const string upward_arrow{ "\u27B6" }; + inline static const string down_arrow{ "\u27B4" }; + inline static const string equal_best{ check_mark }; + inline static const string better_best{ black_star }; + }; class ReportBase { public: - explicit ReportBase(json data_) { data = data_; }; + explicit ReportBase(json data_, bool compare); virtual ~ReportBase() = default; void show(); protected: @@ -18,6 +30,15 @@ namespace platform { string fVector(const string& title, const json& data, const int width, const int precision); virtual void header() = 0; virtual void body() = 0; + virtual void showSummary() = 0; + string compareResult(const string& dataset, double result); + map summary; + double margin; + map meaning; + private: + double bestResult(const string& dataset, const string& model); + bool compare; + json bestResults; }; }; #endif \ No newline at end of file diff --git a/src/Platform/ReportConsole.cc b/src/Platform/ReportConsole.cc index acbb602..0de1c11 100644 --- a/src/Platform/ReportConsole.cc +++ b/src/Platform/ReportConsole.cc @@ -11,11 +11,11 @@ namespace platform { string do_grouping() const { return "\03"; } }; - string ReportConsole::headerLine(const string& text) + string ReportConsole::headerLine(const string& text, int utf = 0) { int n = MAXL - text.length() - 3; n = n < 0 ? 0 : n; - return "* " + text + string(n, ' ') + "*\n"; + return "* " + text + string(n + utf, ' ') + "*\n"; } void ReportConsole::header() @@ -36,8 +36,8 @@ namespace platform { } void ReportConsole::body() { - cout << Colors::GREEN() << " # Dataset Sampl. Feat. Cls Nodes Edges States Score Time Hyperparameters" << endl; - cout << "=== ============================== ====== ===== === ========= ========= ========= =============== ================== ===============" << endl; + cout << Colors::GREEN() << " # Dataset Sampl. Feat. Cls Nodes Edges States Score Time Hyperparameters" << endl; + cout << "=== ========================= ====== ===== === ========= ========= ========= =============== =================== ====================" << endl; json lastResult; double totalScore = 0.0; bool odd = true; @@ -50,15 +50,17 @@ namespace platform { auto color = odd ? Colors::CYAN() : Colors::BLUE(); cout << color; cout << setw(3) << index++ << " "; - cout << setw(30) << left << r["dataset"].get() << " "; + cout << setw(25) << left << r["dataset"].get() << " "; cout << setw(6) << right << r["samples"].get() << " "; cout << setw(5) << right << r["features"].get() << " "; cout << setw(3) << right << r["classes"].get() << " "; cout << setw(9) << setprecision(2) << fixed << r["nodes"].get() << " "; cout << setw(9) << setprecision(2) << fixed << r["leaves"].get() << " "; cout << setw(9) << setprecision(2) << fixed << r["depth"].get() << " "; - cout << setw(8) << right << setprecision(6) << fixed << r["score"].get() << "±" << setw(6) << setprecision(4) << fixed << r["score_std"].get() << " "; - cout << setw(11) << right << setprecision(6) << fixed << r["time"].get() << "±" << setw(6) << setprecision(4) << fixed << r["time_std"].get() << " "; + cout << setw(8) << right << setprecision(6) << fixed << r["score"].get() << "±" << setw(6) << setprecision(4) << fixed << r["score_std"].get(); + const string status = compareResult(r["dataset"].get(), r["score"].get()); + cout << status; + cout << setw(12) << right << setprecision(6) << fixed << r["time"].get() << "±" << setw(6) << setprecision(4) << fixed << r["time_std"].get() << " "; try { cout << r["hyperparameters"].get(); } @@ -81,9 +83,21 @@ namespace platform { footer(totalScore); } } + void ReportConsole::showSummary() + { + for (const auto& item : summary) { + stringstream oss; + oss << setw(3) << left << item.first; + oss << setw(3) << right << item.second << " "; + oss << left << meaning.at(item.first); + cout << headerLine(oss.str(), 2); + } + } + void ReportConsole::footer(double totalScore) { cout << Colors::MAGENTA() << string(MAXL, '*') << endl; + showSummary(); auto score = data["score_name"].get(); if (score == BestResult::scoreName()) { stringstream oss; diff --git a/src/Platform/ReportConsole.h b/src/Platform/ReportConsole.h index b34e71f..3dcc719 100644 --- a/src/Platform/ReportConsole.h +++ b/src/Platform/ReportConsole.h @@ -7,17 +7,18 @@ namespace platform { using namespace std; - const int MAXL = 132; + const int MAXL = 133; class ReportConsole : public ReportBase { public: - explicit ReportConsole(json data_, int index = -1) : ReportBase(data_), selectedIndex(index) {}; + explicit ReportConsole(json data_, bool compare = false, int index = -1) : ReportBase(data_, compare), selectedIndex(index) {}; virtual ~ReportConsole() = default; private: int selectedIndex; - string headerLine(const string& text); + string headerLine(const string& text, int utf); void header() override; void body() override; void footer(double totalScore); + void showSummary(); }; }; #endif \ No newline at end of file diff --git a/src/Platform/ReportExcel.cc b/src/Platform/ReportExcel.cc index 0ed5854..41891e9 100644 --- a/src/Platform/ReportExcel.cc +++ b/src/Platform/ReportExcel.cc @@ -13,17 +13,195 @@ namespace platform { string do_grouping() const { return "\03"; } }; + ReportExcel::ReportExcel(json data_, bool compare, lxw_workbook* workbook) : ReportBase(data_, compare), row(0), workbook(workbook) + { + normalSize = 14; //font size for report body + colorTitle = 0xB1A0C7; + colorOdd = 0xDCE6F1; + colorEven = 0xFDE9D9; + createFile(); + } + + lxw_workbook* ReportExcel::getWorkbook() + { + return workbook; + } + + lxw_format* ReportExcel::efectiveStyle(const string& style) + { + lxw_format* efectiveStyle; + if (style == "") { + efectiveStyle = NULL; + } else { + string suffix = row % 2 ? "_odd" : "_even"; + efectiveStyle = styles.at(style + suffix); + } + return efectiveStyle; + } + + void ReportExcel::writeString(int row, int col, const string& text, const string& style) + { + worksheet_write_string(worksheet, row, col, text.c_str(), efectiveStyle(style)); + } + void ReportExcel::writeInt(int row, int col, const int number, const string& style) + { + worksheet_write_number(worksheet, row, col, number, efectiveStyle(style)); + } + void ReportExcel::writeDouble(int row, int col, const double number, const string& style) + { + worksheet_write_number(worksheet, row, col, number, efectiveStyle(style)); + } + + void ReportExcel::formatColumns() + { + worksheet_freeze_panes(worksheet, 6, 1); + vector columns_sizes = { 22, 10, 9, 7, 12, 12, 12, 12, 12, 3, 15, 12, 23 }; + for (int i = 0; i < columns_sizes.size(); ++i) { + worksheet_set_column(worksheet, i, i, columns_sizes.at(i), NULL); + } + } + + void ReportExcel::addColor(lxw_format* style, bool odd) + { + uint32_t efectiveColor = odd ? colorEven : colorOdd; + format_set_bg_color(style, lxw_color_t(efectiveColor)); + } + void ReportExcel::createStyle(const string& name, lxw_format* style, bool odd) + { + addColor(style, odd); + if (name == "textCentered") { + format_set_align(style, LXW_ALIGN_CENTER); + format_set_font_size(style, normalSize); + format_set_border(style, LXW_BORDER_THIN); + } else if (name == "text") { + format_set_font_size(style, normalSize); + format_set_border(style, LXW_BORDER_THIN); + } else if (name == "bodyHeader") { + format_set_bold(style); + format_set_font_size(style, normalSize); + format_set_align(style, LXW_ALIGN_CENTER); + format_set_align(style, LXW_ALIGN_VERTICAL_CENTER); + format_set_border(style, LXW_BORDER_THIN); + format_set_bg_color(style, lxw_color_t(colorTitle)); + } else if (name == "result") { + format_set_font_size(style, normalSize); + format_set_border(style, LXW_BORDER_THIN); + format_set_num_format(style, "0.0000000"); + } else if (name == "time") { + format_set_font_size(style, normalSize); + format_set_border(style, LXW_BORDER_THIN); + format_set_num_format(style, "#,##0.000000"); + } else if (name == "ints") { + format_set_font_size(style, normalSize); + format_set_num_format(style, "###,##0"); + format_set_border(style, LXW_BORDER_THIN); + } else if (name == "floats") { + format_set_border(style, LXW_BORDER_THIN); + format_set_font_size(style, normalSize); + format_set_num_format(style, "#,##0.00"); + } + } + + void ReportExcel::createFormats() + { + auto styleNames = { "text", "textCentered", "bodyHeader", "result", "time", "ints", "floats" }; + lxw_format* style; + for (string name : styleNames) { + lxw_format* style = workbook_add_format(workbook); + style = workbook_add_format(workbook); + createStyle(name, style, true); + styles[name + "_odd"] = style; + style = workbook_add_format(workbook); + createStyle(name, style, false); + styles[name + "_even"] = style; + } + + // Header 1st line + lxw_format* headerFirst = workbook_add_format(workbook); + format_set_bold(headerFirst); + format_set_font_size(headerFirst, 18); + format_set_align(headerFirst, LXW_ALIGN_CENTER); + format_set_align(headerFirst, LXW_ALIGN_VERTICAL_CENTER); + format_set_border(headerFirst, LXW_BORDER_THIN); + format_set_bg_color(headerFirst, lxw_color_t(colorTitle)); + + // Header rest + lxw_format* headerRest = workbook_add_format(workbook); + format_set_bold(headerRest); + format_set_align(headerRest, LXW_ALIGN_CENTER); + format_set_font_size(headerRest, 16); + format_set_align(headerRest, LXW_ALIGN_VERTICAL_CENTER); + format_set_border(headerRest, LXW_BORDER_THIN); + format_set_bg_color(headerRest, lxw_color_t(colorOdd)); + + // Header small + lxw_format* headerSmall = workbook_add_format(workbook); + format_set_bold(headerSmall); + format_set_align(headerSmall, LXW_ALIGN_LEFT); + format_set_font_size(headerSmall, 12); + format_set_border(headerSmall, LXW_BORDER_THIN); + format_set_align(headerSmall, LXW_ALIGN_VERTICAL_CENTER); + format_set_bg_color(headerSmall, lxw_color_t(colorOdd)); + + // Summary style + lxw_format* summaryStyle = workbook_add_format(workbook); + format_set_bold(summaryStyle); + format_set_font_size(summaryStyle, 16); + format_set_border(summaryStyle, LXW_BORDER_THIN); + format_set_align(summaryStyle, LXW_ALIGN_VERTICAL_CENTER); + + styles["headerFirst"] = headerFirst; + styles["headerRest"] = headerRest; + styles["headerSmall"] = headerSmall; + styles["summaryStyle"] = summaryStyle; + } + + void ReportExcel::setProperties() + { + char line[data["title"].get().size() + 1]; + strcpy(line, data["title"].get().c_str()); + lxw_doc_properties properties = { + .title = line, + .subject = "Machine learning results", + .author = "Ricardo Montañana Gómez", + .manager = "Dr. J. A. Gámez, Dr. J. M. Puerta", + .company = "UCLM", + .comments = "Created with libxlsxwriter and c++", + }; + workbook_set_properties(workbook, &properties); + } + void ReportExcel::createFile() { - doc.create(Paths::excel() + "some_results.xlsx"); - wks = doc.workbook().worksheet("Sheet1"); - wks.setName(data["model"].get()); + if (workbook == NULL) { + workbook = workbook_new((Paths::excel() + fileName).c_str()); + } + const string name = data["model"].get(); + string suffix = ""; + string efectiveName; + int num = 1; + // Create a sheet with the name of the model + while (true) { + efectiveName = name + suffix; + if (workbook_get_worksheet_by_name(workbook, efectiveName.c_str())) { + suffix = to_string(++num); + } else { + worksheet = workbook_add_worksheet(workbook, efectiveName.c_str()); + break; + } + if (num > 100) { + throw invalid_argument("Couldn't create sheet " + efectiveName); + } + } + cout << "Adding sheet " << efectiveName << " to " << Paths::excel() + fileName << endl; + setProperties(); + createFormats(); + formatColumns(); } void ReportExcel::closeFile() { - doc.save(); - doc.close(); + workbook_close(workbook); } void ReportExcel::header() @@ -32,45 +210,62 @@ namespace platform { locale::global(mylocale); cout.imbue(mylocale); stringstream oss; - wks.cell("A1").value().set( - "Report " + data["model"].get() + " ver. " + data["version"].get() + " with " + - to_string(data["folds"].get()) + " Folds cross validation and " + to_string(data["seeds"].size()) + - " random seeds. " + data["date"].get() + " " + data["time"].get()); - wks.cell("A2").value() = data["title"].get(); - wks.cell("A3").value() = "Random seeds: " + fromVector("seeds") + " Stratified: " + - (data["stratified"].get() ? "True" : "False"); - oss << "Execution took " << setprecision(2) << fixed << data["duration"].get() << " seconds, " - << data["duration"].get() / 3600 << " hours, on " << data["platform"].get(); - wks.cell("A4").value() = oss.str(); - wks.cell("A5").value() = "Score is " + data["score_name"].get(); + string message = data["model"].get() + " ver. " + data["version"].get() + " " + + data["language"].get() + " ver. " + data["language_version"].get() + + " with " + to_string(data["folds"].get()) + " Folds cross validation and " + to_string(data["seeds"].size()) + + " random seeds. " + data["date"].get() + " " + data["time"].get(); + worksheet_merge_range(worksheet, 0, 0, 0, 12, message.c_str(), styles["headerFirst"]); + worksheet_merge_range(worksheet, 1, 0, 1, 12, data["title"].get().c_str(), styles["headerRest"]); + worksheet_merge_range(worksheet, 2, 0, 3, 0, ("Score is " + data["score_name"].get()).c_str(), styles["headerRest"]); + worksheet_merge_range(worksheet, 2, 1, 3, 3, "Execution time", styles["headerRest"]); + oss << setprecision(2) << fixed << data["duration"].get() << " s"; + worksheet_merge_range(worksheet, 2, 4, 2, 5, oss.str().c_str(), styles["headerRest"]); + oss.str(""); + oss.clear(); + oss << setprecision(2) << fixed << data["duration"].get() / 3600 << " h"; + worksheet_merge_range(worksheet, 3, 4, 3, 5, oss.str().c_str(), styles["headerRest"]); + worksheet_merge_range(worksheet, 2, 6, 3, 7, "Platform", styles["headerRest"]); + worksheet_merge_range(worksheet, 2, 8, 3, 9, data["platform"].get().c_str(), styles["headerRest"]); + worksheet_merge_range(worksheet, 2, 10, 2, 12, ("Random seeds: " + fromVector("seeds")).c_str(), styles["headerSmall"]); + oss.str(""); + oss.clear(); + oss << "Stratified: " << (data["stratified"].get() ? "True" : "False"); + worksheet_merge_range(worksheet, 3, 10, 3, 11, oss.str().c_str(), styles["headerSmall"]); + oss.str(""); + oss.clear(); + oss << "Discretized: " << (data["discretized"].get() ? "True" : "False"); + worksheet_write_string(worksheet, 3, 12, oss.str().c_str(), styles["headerSmall"]); } void ReportExcel::body() { auto head = vector( - { "Dataset", "Samples", "Features", "Classes", "Nodes", "Edges", "States", "Score", "Score Std.", "Time", + { "Dataset", "Samples", "Features", "Classes", "Nodes", "Edges", "States", "Score", "Score Std.", "St.", "Time", "Time Std.", "Hyperparameters" }); - int col = 1; + int col = 0; for (const auto& item : head) { - wks.cell(8, col++).value() = item; + writeString(5, col++, item, "bodyHeader"); } - int row = 9; - col = 1; + row = 6; + col = 0; + int hypSize = 22; json lastResult; double totalScore = 0.0; string hyperparameters; for (const auto& r : data["results"]) { - wks.cell(row, col).value() = r["dataset"].get(); - wks.cell(row, col + 1).value() = r["samples"].get(); - wks.cell(row, col + 2).value() = r["features"].get(); - wks.cell(row, col + 3).value() = r["classes"].get(); - wks.cell(row, col + 4).value() = r["nodes"].get(); - wks.cell(row, col + 5).value() = r["leaves"].get(); - wks.cell(row, col + 6).value() = r["depth"].get(); - wks.cell(row, col + 7).value() = r["score"].get(); - wks.cell(row, col + 8).value() = r["score_std"].get(); - wks.cell(row, col + 9).value() = r["time"].get(); - wks.cell(row, col + 10).value() = r["time_std"].get(); + writeString(row, col, r["dataset"].get(), "text"); + writeInt(row, col + 1, r["samples"].get(), "ints"); + writeInt(row, col + 2, r["features"].get(), "ints"); + writeInt(row, col + 3, r["classes"].get(), "ints"); + writeDouble(row, col + 4, r["nodes"].get(), "floats"); + writeDouble(row, col + 5, r["leaves"].get(), "floats"); + writeDouble(row, col + 6, r["depth"].get(), "floats"); + writeDouble(row, col + 7, r["score"].get(), "result"); + writeDouble(row, col + 8, r["score_std"].get(), "result"); + const string status = compareResult(r["dataset"].get(), r["score"].get()); + writeString(row, col + 9, status, "textCentered"); + writeDouble(row, col + 10, r["time"].get(), "time"); + writeDouble(row, col + 11, r["time_std"].get(), "time"); try { hyperparameters = r["hyperparameters"].get(); } @@ -79,31 +274,57 @@ namespace platform { oss << r["hyperparameters"]; hyperparameters = oss.str(); } - wks.cell(row, col + 11).value() = hyperparameters; + if (hyperparameters.size() > hypSize) { + hypSize = hyperparameters.size(); + } + writeString(row, col + 12, hyperparameters, "text"); lastResult = r; totalScore += r["score"].get(); row++; + } + // Set the right column width of hyperparameters with the maximum length + worksheet_set_column(worksheet, 12, 12, hypSize + 5, NULL); + // Show totals if only one dataset is present in the result if (data["results"].size() == 1) { for (const string& group : { "scores_train", "scores_test", "times_train", "times_test" }) { row++; col = 1; - wks.cell(row, col).value() = group; + writeString(row, col, group, "text"); for (double item : lastResult[group]) { - wks.cell(row, ++col).value() = item; + string style = group.find("scores") != string::npos ? "result" : "time"; + writeDouble(row, ++col, item, style); } } + // Set with of columns to show those totals completely + worksheet_set_column(worksheet, 1, 1, 12, NULL); + for (int i = 2; i < 7; ++i) { + // doesn't work with from col to col, so... + worksheet_set_column(worksheet, i, i, 15, NULL); + } } else { footer(totalScore, row); } } + void ReportExcel::showSummary() + { + for (const auto& item : summary) { + worksheet_write_string(worksheet, row + 2, 1, item.first.c_str(), styles["summaryStyle"]); + worksheet_write_number(worksheet, row + 2, 2, item.second, styles["summaryStyle"]); + worksheet_merge_range(worksheet, row + 2, 3, row + 2, 5, meaning.at(item.first).c_str(), styles["summaryStyle"]); + row += 1; + } + } + void ReportExcel::footer(double totalScore, int row) { + showSummary(); + row += 4 + summary.size(); auto score = data["score_name"].get(); if (score == BestResult::scoreName()) { - wks.cell(row + 2, 1).value() = score + " compared to " + BestResult::title() + " .: "; - wks.cell(row + 2, 5).value() = totalScore / BestResult::score(); + worksheet_merge_range(worksheet, row, 1, row, 5, (score + " compared to " + BestResult::title() + " .:").c_str(), efectiveStyle("text")); + writeDouble(row, 6, totalScore / BestResult::score(), "result"); } } } \ No newline at end of file diff --git a/src/Platform/ReportExcel.h b/src/Platform/ReportExcel.h index 3700681..c5d462f 100644 --- a/src/Platform/ReportExcel.h +++ b/src/Platform/ReportExcel.h @@ -1,25 +1,42 @@ #ifndef REPORTEXCEL_H #define REPORTEXCEL_H -#include +#include +#include "xlsxwriter.h" #include "ReportBase.h" -#include "Paths.h" #include "Colors.h" namespace platform { using namespace std; - using namespace OpenXLSX; const int MAXLL = 128; - class ReportExcel : public ReportBase{ + + class ReportExcel : public ReportBase { public: - explicit ReportExcel(json data_) : ReportBase(data_) {createFile();}; - virtual ~ReportExcel() {closeFile();}; + explicit ReportExcel(json data_, bool compare, lxw_workbook* workbook); + lxw_workbook* getWorkbook(); private: + void writeString(int row, int col, const string& text, const string& style = ""); + void writeInt(int row, int col, const int number, const string& style = ""); + void writeDouble(int row, int col, const double number, const string& style = ""); + void formatColumns(); + void createFormats(); + void setProperties(); void createFile(); void closeFile(); - XLDocument doc; - XLWorksheet wks; + void showSummary(); + lxw_workbook* workbook; + lxw_worksheet* worksheet; + map styles; + int row; + int normalSize; //font size for report body + uint32_t colorTitle; + uint32_t colorOdd; + uint32_t colorEven; + const string fileName = "some_results.xlsx"; void header() override; void body() override; void footer(double totalScore, int row); + void createStyle(const string& name, lxw_format* style, bool odd); + void addColor(lxw_format* style, bool odd); + lxw_format* efectiveStyle(const string& name); }; }; #endif // !REPORTEXCEL_H \ No newline at end of file diff --git a/src/Platform/Results.cc b/src/Platform/Results.cc index 3566ab7..725952c 100644 --- a/src/Platform/Results.cc +++ b/src/Platform/Results.cc @@ -104,15 +104,17 @@ namespace platform { cout << "Invalid index" << endl; return -1; } - void Results::report(const int index, const bool excelReport) const + void Results::report(const int index, const bool excelReport) { cout << Colors::YELLOW() << "Reporting " << files.at(index).getFilename() << endl; auto data = files.at(index).load(); if (excelReport) { - ReportExcel reporter(data); + ReportExcel reporter(data, compare, workbook); reporter.show(); + openExcel = true; + workbook = reporter.getWorkbook(); } else { - ReportConsole reporter(data); + ReportConsole reporter(data, compare); reporter.show(); } } @@ -124,7 +126,7 @@ namespace platform { return; } cout << Colors::YELLOW() << "Showing " << files.at(index).getFilename() << endl; - ReportConsole reporter(data, idx); + ReportConsole reporter(data, compare, idx); reporter.show(); } void Results::menu() @@ -132,9 +134,21 @@ namespace platform { char option; int index; bool finished = false; + string color, context; string filename, line, options = "qldhsre"; while (!finished) { - cout << Colors::RESET() << "Choose option (quit='q', list='l', delete='d', hide='h', sort='s', report='r', excel='e'): "; + if (indexList) { + color = Colors::GREEN(); + context = " (quit='q', list='l', delete='d', hide='h', sort='s', report='r', excel='e'): "; + options = "qldhsre"; + } else { + color = Colors::MAGENTA(); + context = " (quit='q', list='l'): "; + options = "ql"; + } + cout << Colors::RESET() << color; + + cout << "Choose option " << context; getline(cin, line); if (line.size() == 0) continue; @@ -148,6 +162,7 @@ namespace platform { if (all_of(line.begin(), line.end(), ::isdigit)) { int idx = stoi(line); if (indexList) { + // The value is about the files list index = idx; if (index >= 0 && index < files.size()) { report(index, false); @@ -155,6 +170,7 @@ namespace platform { continue; } } else { + // The value is about the result showed on screen showIndex(index, idx); continue; } @@ -281,6 +297,9 @@ namespace platform { sortDate(); show(); menu(); + if (openExcel) { + workbook_close(workbook); + } cout << "Done!" << endl; } diff --git a/src/Platform/Results.h b/src/Platform/Results.h index 3f5655a..60748ba 100644 --- a/src/Platform/Results.h +++ b/src/Platform/Results.h @@ -1,5 +1,6 @@ #ifndef RESULTS_H #define RESULTS_H +#include "xlsxwriter.h" #include #include #include @@ -34,7 +35,11 @@ namespace platform { }; class Results { public: - Results(const string& path, const int max, const string& model, const string& score, bool complete, bool partial) : path(path), max(max), model(model), scoreName(score), complete(complete), partial(partial) { load(); }; + Results(const string& path, const int max, const string& model, const string& score, bool complete, bool partial, bool compare) : + path(path), max(max), model(model), scoreName(score), complete(complete), partial(partial), compare(compare) + { + load(); + }; void manage(); private: string path; @@ -44,10 +49,13 @@ namespace platform { bool complete; bool partial; bool indexList = true; + bool openExcel = false; + bool compare; + lxw_workbook* workbook = NULL; vector files; void load(); // Loads the list of results void show() const; - void report(const int index, const bool excelReport) const; + void report(const int index, const bool excelReport); void showIndex(const int index, const int idx) const; int getIndex(const string& intent) const; void menu(); diff --git a/src/Platform/main.cc b/src/Platform/main.cc index 9f8e00b..a122ad2 100644 --- a/src/Platform/main.cc +++ b/src/Platform/main.cc @@ -87,7 +87,7 @@ int main(int argc, char** argv) auto stratified = program.get("stratified"); auto n_folds = program.get("folds"); auto seeds = program.get>("seeds"); - auto hyperparameters =program.get("hyperparameters"); + auto hyperparameters = program.get("hyperparameters"); vector filesToTest; auto datasets = platform::Datasets(path, true, platform::ARFF); auto title = program.get("title"); @@ -102,7 +102,7 @@ int main(int argc, char** argv) } filesToTest.push_back(file_name); } else { - filesToTest = platform::Datasets(path, true, platform::ARFF).getNames(); + filesToTest = datasets.getNames(); saveResults = true; } /* diff --git a/src/Platform/manage.cc b/src/Platform/manage.cc index aec19e7..cf699d6 100644 --- a/src/Platform/manage.cc +++ b/src/Platform/manage.cc @@ -14,6 +14,7 @@ argparse::ArgumentParser manageArguments(int argc, char** argv) program.add_argument("-s", "--score").default_value("any").help("Filter results of the score name supplied"); program.add_argument("--complete").help("Show only results with all datasets").default_value(false).implicit_value(true); program.add_argument("--partial").help("Show only partial results").default_value(false).implicit_value(true); + program.add_argument("--compare").help("Compare with best results").default_value(false).implicit_value(true); try { program.parse_args(argc, argv); auto number = program.get("number"); @@ -24,6 +25,7 @@ argparse::ArgumentParser manageArguments(int argc, char** argv) auto score = program.get("score"); auto complete = program.get("complete"); auto partial = program.get("partial"); + auto compare = program.get("compare"); } catch (const exception& err) { cerr << err.what() << endl; @@ -41,9 +43,10 @@ int main(int argc, char** argv) auto score = program.get("score"); auto complete = program.get("complete"); auto partial = program.get("partial"); + auto compare = program.get("compare"); if (complete) partial = false; - auto results = platform::Results(platform::Paths::results(), number, model, score, complete, partial); + auto results = platform::Results(platform::Paths::results(), number, model, score, complete, partial, compare); results.manage(); return 0; }