diff --git a/.gitmodules b/.gitmodules index 626d10f..235f4f1 100644 --- a/.gitmodules +++ b/.gitmodules @@ -13,3 +13,6 @@ [submodule "lib/openXLSX"] path = lib/openXLSX url = https://github.com/troldal/OpenXLSX.git +[submodule "lib/libxlsxwriter"] + path = lib/libxlsxwriter + url = https://github.com/jmcnamara/libxlsxwriter.git diff --git a/CMakeLists.txt b/CMakeLists.txt index be56d34..d74de50 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -55,6 +55,7 @@ add_git_submodule("lib/mdlp") add_git_submodule("lib/argparse") add_git_submodule("lib/json") add_git_submodule("lib/openXLSX") +#add_git_submodule("lib/libxlsxwriter") # Subdirectories # -------------- @@ -64,6 +65,9 @@ add_subdirectory(src/BayesNet) add_subdirectory(src/Platform) add_subdirectory(sample) +#find_package(PkgConfig REQUIRED) +#pkg_check_modules(Xlsxwriter REQUIRED IMPORTED_TARGET libxlsxwriter) + file(GLOB BayesNet_HEADERS CONFIGURE_DEPENDS ${BayesNet_SOURCE_DIR}/src/BayesNet/*.h ${BayesNet_SOURCE_DIR}/BayesNet/*.hpp) file(GLOB BayesNet_SOURCES CONFIGURE_DEPENDS ${BayesNet_SOURCE_DIR}/src/BayesNet/*.cc ${BayesNet_SOURCE_DIR}/src/BayesNet/*.cpp) file(GLOB Platform_SOURCES CONFIGURE_DEPENDS ${BayesNet_SOURCE_DIR}/src/Platform/*.cc ${BayesNet_SOURCE_DIR}/src/Platform/*.cpp) diff --git a/README.md b/README.md index c95d141..266bb8a 100644 --- a/README.md +++ b/README.md @@ -2,4 +2,36 @@ Bayesian Network Classifier with libtorch from scratch +## 0. Setup + +### libxlswriter + +Before compiling BayesNet. + +```bash +cd lib/libxlsxwriter +make +sudo make install +``` + +It has to be installed in /usr/local/lib otherwise CMakeLists.txt has to be modified accordingly + +Environment variable has to be set: + +```bash + export LD_LIBRARY_PATH=/usr/local/lib + ``` + +### Release + +```bash +make release +``` + +### Debug & Tests + +```bash +make debug +``` + ## 1. Introduction diff --git a/lib/catch2 b/lib/catch2 index 4acc518..9c541ca 160000 --- a/lib/catch2 +++ b/lib/catch2 @@ -1 +1 @@ -Subproject commit 4acc51828f7f93f3b2058a63f54d112af4034503 +Subproject commit 9c541ca72e7857dec71d8a41b97e42c2f1c92602 diff --git a/lib/libxlsxwriter b/lib/libxlsxwriter new file mode 160000 index 0000000..44e72c5 --- /dev/null +++ b/lib/libxlsxwriter @@ -0,0 +1 @@ +Subproject commit 44e72c5862f9d549453a4ff6e8ceab0da19705e5 diff --git a/src/Platform/CMakeLists.txt b/src/Platform/CMakeLists.txt index b5dc5b9..b885792 100644 --- a/src/Platform/CMakeLists.txt +++ b/src/Platform/CMakeLists.txt @@ -9,8 +9,8 @@ add_executable(manage manage.cc Results.cc ReportConsole.cc ReportExcel.cc Repor add_executable(list list.cc platformUtils Datasets.cc) target_link_libraries(main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}") if (${CMAKE_HOST_SYSTEM_NAME} MATCHES "Linux") - target_link_libraries(manage "${TORCH_LIBRARIES}" OpenXLSX::OpenXLSX stdc++fs) + target_link_libraries(manage "${TORCH_LIBRARIES}" libxlsxwriter.so stdc++fs) + target_link_libraries(manage "${TORCH_LIBRARIES}" libxlsxwriter.so) else() - target_link_libraries(manage "${TORCH_LIBRARIES}" OpenXLSX::OpenXLSX) endif() target_link_libraries(list ArffFiles mdlp "${TORCH_LIBRARIES}") \ No newline at end of file diff --git a/src/Platform/ReportExcel.cc b/src/Platform/ReportExcel.cc index 0ed5854..3c011aa 100644 --- a/src/Platform/ReportExcel.cc +++ b/src/Platform/ReportExcel.cc @@ -12,18 +12,72 @@ namespace platform { string do_grouping() const { return "\03"; } }; + void ReportExcel::writeString(int row, int col, const string& text, const string& style) + { + lxw_format* efectiveStyle = style == "" ? NULL : styles[style]; + worksheet_write_string(worksheet, row, col, text.c_str(), efectiveStyle); + } + void ReportExcel::writeInt(int row, int col, const int number, const string& style) + { + lxw_format* efectiveStyle = style == "" ? NULL : styles[style]; + worksheet_write_number(worksheet, row, col, number, efectiveStyle); + } + void ReportExcel::writeDouble(int row, int col, const double number, const string& style) + { + lxw_format* efectiveStyle = style == "" ? NULL : styles[style]; + worksheet_write_number(worksheet, row, col, number, efectiveStyle); + } + + void ReportExcel::formatHeader() + { + worksheet_freeze_panes(worksheet, 8, 0); + } + + void ReportExcel::formatBody() + { + + } + + void ReportExcel::formatFooter() + { + + } + + void ReportExcel::createFormats() + { + lxw_format* bold = workbook_add_format(workbook); + format_set_bold(bold); + + lxw_format* result = workbook_add_format(workbook); + format_set_num_format(result, "0.0000000"); + + lxw_format* timeStyle = workbook_add_format(workbook); + format_set_num_format(timeStyle, "#,##0.00"); + + lxw_format* ints = workbook_add_format(workbook); + format_set_num_format(ints, "###,###"); + + lxw_format* floats = workbook_add_format(workbook); + format_set_num_format(floats, "#,###.00"); + + styles["bold"] = bold; + styles["result"] = result; + styles["time"] = timeStyle; + styles["ints"] = ints; + styles["floats"] = floats + } void ReportExcel::createFile() { - doc.create(Paths::excel() + "some_results.xlsx"); - wks = doc.workbook().worksheet("Sheet1"); - wks.setName(data["model"].get()); + workbook = workbook_new((Paths::excel() + "some_results.xlsx").c_str()); + const string name = data["model"].get(); + worksheet = workbook_add_worksheet(workbook, name.c_str()); + createFormats(); } void ReportExcel::closeFile() { - doc.save(); - doc.close(); + workbook_close(workbook); } void ReportExcel::header() @@ -32,17 +86,17 @@ namespace platform { locale::global(mylocale); cout.imbue(mylocale); stringstream oss; - wks.cell("A1").value().set( - "Report " + data["model"].get() + " ver. " + data["version"].get() + " with " + + writeString(0, 0, "Report " + data["model"].get() + " ver. " + data["version"].get() + " with " + to_string(data["folds"].get()) + " Folds cross validation and " + to_string(data["seeds"].size()) + - " random seeds. " + data["date"].get() + " " + data["time"].get()); - wks.cell("A2").value() = data["title"].get(); - wks.cell("A3").value() = "Random seeds: " + fromVector("seeds") + " Stratified: " + - (data["stratified"].get() ? "True" : "False"); + " random seeds. " + data["date"].get() + " " + data["time"].get(), "bold"); + writeString(1, 0, data["title"].get()); + writeString(2, 0, "Random seeds: " + fromVector("seeds") + " Stratified: " + + (data["stratified"].get() ? "True" : "False")); oss << "Execution took " << setprecision(2) << fixed << data["duration"].get() << " seconds, " << data["duration"].get() / 3600 << " hours, on " << data["platform"].get(); - wks.cell("A4").value() = oss.str(); - wks.cell("A5").value() = "Score is " + data["score_name"].get(); + writeString(3, 0, oss.str()); + writeString(4, 0, "Score is " + data["score_name"].get()); + formatHeader(); } void ReportExcel::body() @@ -52,7 +106,7 @@ namespace platform { "Time Std.", "Hyperparameters" }); int col = 1; for (const auto& item : head) { - wks.cell(8, col++).value() = item; + writeString(8, col++, item); } int row = 9; col = 1; @@ -60,17 +114,17 @@ namespace platform { double totalScore = 0.0; string hyperparameters; for (const auto& r : data["results"]) { - wks.cell(row, col).value() = r["dataset"].get(); - wks.cell(row, col + 1).value() = r["samples"].get(); - wks.cell(row, col + 2).value() = r["features"].get(); - wks.cell(row, col + 3).value() = r["classes"].get(); - wks.cell(row, col + 4).value() = r["nodes"].get(); - wks.cell(row, col + 5).value() = r["leaves"].get(); - wks.cell(row, col + 6).value() = r["depth"].get(); - wks.cell(row, col + 7).value() = r["score"].get(); - wks.cell(row, col + 8).value() = r["score_std"].get(); - wks.cell(row, col + 9).value() = r["time"].get(); - wks.cell(row, col + 10).value() = r["time_std"].get(); + writeString(row, col, r["dataset"].get()); + writeInt(row, col + 1, r["samples"].get(), "ints"); + writeInt(row, col + 2, r["features"].get(), "ints"); + writeInt(row, col + 3, r["classes"].get(), "ints"); + writeDouble(row, col + 4, r["nodes"].get(), "floats"); + writeDouble(row, col + 5, r["leaves"].get(), "floats"); + writeDouble(row, col + 6, r["depth"].get(), "floats"); + writeDouble(row, col + 7, r["score"].get(), "result"); + writeDouble(row, col + 8, r["score_std"].get(), "result"); + writeDouble(row, col + 9, r["time"].get(), "time"); + writeDouble(row, col + 10, r["time_std"].get(), "time"); try { hyperparameters = r["hyperparameters"].get(); } @@ -79,7 +133,7 @@ namespace platform { oss << r["hyperparameters"]; hyperparameters = oss.str(); } - wks.cell(row, col + 11).value() = hyperparameters; + writeString(row, col + 11, hyperparameters); lastResult = r; totalScore += r["score"].get(); row++; @@ -88,22 +142,24 @@ namespace platform { for (const string& group : { "scores_train", "scores_test", "times_train", "times_test" }) { row++; col = 1; - wks.cell(row, col).value() = group; + writeString(row, col, group); for (double item : lastResult[group]) { - wks.cell(row, ++col).value() = item; + writeDouble(row, ++col, item); } } } else { footer(totalScore, row); } + formatBody(); } void ReportExcel::footer(double totalScore, int row) { auto score = data["score_name"].get(); if (score == BestResult::scoreName()) { - wks.cell(row + 2, 1).value() = score + " compared to " + BestResult::title() + " .: "; - wks.cell(row + 2, 5).value() = totalScore / BestResult::score(); + writeString(row + 2, 1, score + " compared to " + BestResult::title() + " .: "); + writeDouble(row + 2, 7, totalScore / BestResult::score(), "result"); } + formatFooter(); } } \ No newline at end of file diff --git a/src/Platform/ReportExcel.h b/src/Platform/ReportExcel.h index 3700681..69d9f69 100644 --- a/src/Platform/ReportExcel.h +++ b/src/Platform/ReportExcel.h @@ -1,25 +1,34 @@ #ifndef REPORTEXCEL_H #define REPORTEXCEL_H -#include +#include +#include "xlsxwriter.h" #include "ReportBase.h" #include "Paths.h" #include "Colors.h" namespace platform { using namespace std; - using namespace OpenXLSX; const int MAXLL = 128; - class ReportExcel : public ReportBase{ + class ReportExcel : public ReportBase { public: - explicit ReportExcel(json data_) : ReportBase(data_) {createFile();}; - virtual ~ReportExcel() {closeFile();}; - private: + explicit ReportExcel(json data_) : ReportBase(data_) { createFile(); }; + virtual ~ReportExcel() { closeFile(); }; + protected: + void writeString(int row, int col, const string& text, const string& style = ""); + void writeInt(int row, int col, const int number, const string& style = ""); + void writeDouble(int row, int col, const double number, const string& style = ""); + void formatHeader(); + void formatBody(); + void formatFooter(); + void createFormats(); + private: galeote void createFile(); - void closeFile(); - XLDocument doc; - XLWorksheet wks; - void header() override; - void body() override; - void footer(double totalScore, int row); + void closeFile(); + lxw_workbook* workbook; + lxw_worksheet* worksheet; + map styles; + void header() override; + void body() override; + void footer(double totalScore, int row); }; }; #endif // !REPORTEXCEL_H \ No newline at end of file