First approach

This commit is contained in:
Ricardo Montañana Gómez 2023-09-18 23:26:22 +02:00
parent 501ea0ab4e
commit 06fb135526
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
8 changed files with 150 additions and 45 deletions

3
.gitmodules vendored
View File

@ -13,3 +13,6 @@
[submodule "lib/openXLSX"]
path = lib/openXLSX
url = https://github.com/troldal/OpenXLSX.git
[submodule "lib/libxlsxwriter"]
path = lib/libxlsxwriter
url = https://github.com/jmcnamara/libxlsxwriter.git

View File

@ -55,6 +55,7 @@ add_git_submodule("lib/mdlp")
add_git_submodule("lib/argparse")
add_git_submodule("lib/json")
add_git_submodule("lib/openXLSX")
#add_git_submodule("lib/libxlsxwriter")
# Subdirectories
# --------------
@ -64,6 +65,9 @@ add_subdirectory(src/BayesNet)
add_subdirectory(src/Platform)
add_subdirectory(sample)
#find_package(PkgConfig REQUIRED)
#pkg_check_modules(Xlsxwriter REQUIRED IMPORTED_TARGET libxlsxwriter)
file(GLOB BayesNet_HEADERS CONFIGURE_DEPENDS ${BayesNet_SOURCE_DIR}/src/BayesNet/*.h ${BayesNet_SOURCE_DIR}/BayesNet/*.hpp)
file(GLOB BayesNet_SOURCES CONFIGURE_DEPENDS ${BayesNet_SOURCE_DIR}/src/BayesNet/*.cc ${BayesNet_SOURCE_DIR}/src/BayesNet/*.cpp)
file(GLOB Platform_SOURCES CONFIGURE_DEPENDS ${BayesNet_SOURCE_DIR}/src/Platform/*.cc ${BayesNet_SOURCE_DIR}/src/Platform/*.cpp)

View File

@ -2,4 +2,36 @@
Bayesian Network Classifier with libtorch from scratch
## 0. Setup
### libxlsxwriter
Build and install libxlsxwriter before compiling BayesNet:
```bash
cd lib/libxlsxwriter
make
sudo make install
```
The library has to be installed in /usr/local/lib; otherwise, CMakeLists.txt has to be modified accordingly.
The following environment variable has to be set:
```bash
export LD_LIBRARY_PATH=/usr/local/lib
```
### Release
```bash
make release
```
### Debug & Tests
```bash
make debug
```
## 1. Introduction

@ -1 +1 @@
Subproject commit 4acc51828f7f93f3b2058a63f54d112af4034503
Subproject commit 9c541ca72e7857dec71d8a41b97e42c2f1c92602

1
lib/libxlsxwriter Submodule

@ -0,0 +1 @@
Subproject commit 44e72c5862f9d549453a4ff6e8ceab0da19705e5

View File

@ -9,8 +9,8 @@ add_executable(manage manage.cc Results.cc ReportConsole.cc ReportExcel.cc Repor
add_executable(list list.cc platformUtils Datasets.cc)
target_link_libraries(main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}")
if (${CMAKE_HOST_SYSTEM_NAME} MATCHES "Linux")
target_link_libraries(manage "${TORCH_LIBRARIES}" OpenXLSX::OpenXLSX stdc++fs)
target_link_libraries(manage "${TORCH_LIBRARIES}" libxlsxwriter.so stdc++fs)
target_link_libraries(manage "${TORCH_LIBRARIES}" libxlsxwriter.so)
else()
target_link_libraries(manage "${TORCH_LIBRARIES}" OpenXLSX::OpenXLSX)
endif()
target_link_libraries(list ArffFiles mdlp "${TORCH_LIBRARIES}")

View File

@ -12,18 +12,72 @@ namespace platform {
string do_grouping() const { return "\03"; }
};
void ReportExcel::writeString(int row, int col, const string& text, const string& style)
{
lxw_format* efectiveStyle = style == "" ? NULL : styles[style];
worksheet_write_string(worksheet, row, col, text.c_str(), efectiveStyle);
}
void ReportExcel::writeInt(int row, int col, const int number, const string& style)
{
lxw_format* efectiveStyle = style == "" ? NULL : styles[style];
worksheet_write_number(worksheet, row, col, number, efectiveStyle);
}
void ReportExcel::writeDouble(int row, int col, const double number, const string& style)
{
lxw_format* efectiveStyle = style == "" ? NULL : styles[style];
worksheet_write_number(worksheet, row, col, number, efectiveStyle);
}
// Freezes the panes of the current worksheet at row 8, column 0.
void ReportExcel::formatHeader()
{
    const int freezeRow = 8;
    const int freezeCol = 0;
    worksheet_freeze_panes(worksheet, freezeRow, freezeCol);
}
// Formatting hook invoked at the end of body(); currently a no-op.
void ReportExcel::formatBody()
{
}
// Formatting hook invoked at the end of footer(); currently a no-op.
void ReportExcel::formatFooter()
{
}
void ReportExcel::createFormats()
{
lxw_format* bold = workbook_add_format(workbook);
format_set_bold(bold);
lxw_format* result = workbook_add_format(workbook);
format_set_num_format(result, "0.0000000");
lxw_format* timeStyle = workbook_add_format(workbook);
format_set_num_format(timeStyle, "#,##0.00");
lxw_format* ints = workbook_add_format(workbook);
format_set_num_format(ints, "###,###");
lxw_format* floats = workbook_add_format(workbook);
format_set_num_format(floats, "#,###.00");
styles["bold"] = bold;
styles["result"] = result;
styles["time"] = timeStyle;
styles["ints"] = ints;
styles["floats"] = floats
}
// Creates the output file "some_results.xlsx" under Paths::excel(), names its
// worksheet after data["model"] and registers the cell formats.
// NOTE(review): this span comes from a diff view, so the calls on doc/wks and
// the calls on workbook/worksheet appear together — presumably the doc/wks
// lines are the ones being removed; confirm against the repository.
void ReportExcel::createFile()
{
doc.create(Paths::excel() + "some_results.xlsx");
wks = doc.workbook().worksheet("Sheet1");
wks.setName(data["model"].get<string>());
workbook = workbook_new((Paths::excel() + "some_results.xlsx").c_str());
// Keep the name in a local so the c_str() pointer stays valid during the call.
const string name = data["model"].get<string>();
worksheet = workbook_add_worksheet(workbook, name.c_str());
createFormats();
}
// Closes the output file; called from the ReportExcel destructor.
// NOTE(review): this span comes from a diff view, so the doc.save()/doc.close()
// pair and workbook_close() appear together — presumably the former are the
// removed lines; confirm against the repository.
// NOTE(review): the value returned by workbook_close() is not inspected here —
// confirm whether a failure should be reported.
void ReportExcel::closeFile()
{
doc.save();
doc.close();
workbook_close(workbook);
}
void ReportExcel::header()
@ -32,17 +86,17 @@ namespace platform {
locale::global(mylocale);
cout.imbue(mylocale);
stringstream oss;
wks.cell("A1").value().set(
"Report " + data["model"].get<string>() + " ver. " + data["version"].get<string>() + " with " +
writeString(0, 0, "Report " + data["model"].get<string>() + " ver. " + data["version"].get<string>() + " with " +
to_string(data["folds"].get<int>()) + " Folds cross validation and " + to_string(data["seeds"].size()) +
" random seeds. " + data["date"].get<string>() + " " + data["time"].get<string>());
wks.cell("A2").value() = data["title"].get<string>();
wks.cell("A3").value() = "Random seeds: " + fromVector("seeds") + " Stratified: " +
(data["stratified"].get<bool>() ? "True" : "False");
" random seeds. " + data["date"].get<string>() + " " + data["time"].get<string>(), "bold");
writeString(1, 0, data["title"].get<string>());
writeString(2, 0, "Random seeds: " + fromVector("seeds") + " Stratified: " +
(data["stratified"].get<bool>() ? "True" : "False"));
oss << "Execution took " << setprecision(2) << fixed << data["duration"].get<float>() << " seconds, "
<< data["duration"].get<float>() / 3600 << " hours, on " << data["platform"].get<string>();
wks.cell("A4").value() = oss.str();
wks.cell("A5").value() = "Score is " + data["score_name"].get<string>();
writeString(3, 0, oss.str());
writeString(4, 0, "Score is " + data["score_name"].get<string>());
formatHeader();
}
void ReportExcel::body()
@ -52,7 +106,7 @@ namespace platform {
"Time Std.", "Hyperparameters" });
int col = 1;
for (const auto& item : head) {
wks.cell(8, col++).value() = item;
writeString(8, col++, item);
}
int row = 9;
col = 1;
@ -60,17 +114,17 @@ namespace platform {
double totalScore = 0.0;
string hyperparameters;
for (const auto& r : data["results"]) {
wks.cell(row, col).value() = r["dataset"].get<string>();
wks.cell(row, col + 1).value() = r["samples"].get<int>();
wks.cell(row, col + 2).value() = r["features"].get<int>();
wks.cell(row, col + 3).value() = r["classes"].get<int>();
wks.cell(row, col + 4).value() = r["nodes"].get<float>();
wks.cell(row, col + 5).value() = r["leaves"].get<float>();
wks.cell(row, col + 6).value() = r["depth"].get<float>();
wks.cell(row, col + 7).value() = r["score"].get<double>();
wks.cell(row, col + 8).value() = r["score_std"].get<double>();
wks.cell(row, col + 9).value() = r["time"].get<double>();
wks.cell(row, col + 10).value() = r["time_std"].get<double>();
writeString(row, col, r["dataset"].get<string>());
writeInt(row, col + 1, r["samples"].get<int>(), "ints");
writeInt(row, col + 2, r["features"].get<int>(), "ints");
writeInt(row, col + 3, r["classes"].get<int>(), "ints");
writeDouble(row, col + 4, r["nodes"].get<float>(), "floats");
writeDouble(row, col + 5, r["leaves"].get<float>(), "floats");
writeDouble(row, col + 6, r["depth"].get<double>(), "floats");
writeDouble(row, col + 7, r["score"].get<double>(), "result");
writeDouble(row, col + 8, r["score_std"].get<double>(), "result");
writeDouble(row, col + 9, r["time"].get<double>(), "time");
writeDouble(row, col + 10, r["time_std"].get<double>(), "time");
try {
hyperparameters = r["hyperparameters"].get<string>();
}
@ -79,7 +133,7 @@ namespace platform {
oss << r["hyperparameters"];
hyperparameters = oss.str();
}
wks.cell(row, col + 11).value() = hyperparameters;
writeString(row, col + 11, hyperparameters);
lastResult = r;
totalScore += r["score"].get<double>();
row++;
@ -88,22 +142,24 @@ namespace platform {
for (const string& group : { "scores_train", "scores_test", "times_train", "times_test" }) {
row++;
col = 1;
wks.cell(row, col).value() = group;
writeString(row, col, group);
for (double item : lastResult[group]) {
wks.cell(row, ++col).value() = item;
writeDouble(row, ++col, item);
}
}
} else {
footer(totalScore, row);
}
formatBody();
}
// Writes the comparison line two rows below `row`, but only when the report's
// score name equals BestResult::scoreName(); then runs formatFooter().
// The value written is totalScore divided by BestResult::score().
// NOTE(review): this span comes from a diff view, so the wks.cell(...) lines
// and the writeString/writeDouble lines appear together — presumably the
// former are the removed lines. The two versions put the ratio in different
// columns (5 vs 7); confirm the intended column.
void ReportExcel::footer(double totalScore, int row)
{
auto score = data["score_name"].get<string>();
if (score == BestResult::scoreName()) {
wks.cell(row + 2, 1).value() = score + " compared to " + BestResult::title() + " .: ";
wks.cell(row + 2, 5).value() = totalScore / BestResult::score();
writeString(row + 2, 1, score + " compared to " + BestResult::title() + " .: ");
writeDouble(row + 2, 7, totalScore / BestResult::score(), "result");
}
formatFooter();
}
}

View File

@ -1,25 +1,34 @@
#ifndef REPORTEXCEL_H
#define REPORTEXCEL_H
#include <OpenXLSX.hpp>
#include<map>
#include "xlsxwriter.h"
#include "ReportBase.h"
#include "Paths.h"
#include "Colors.h"
namespace platform {
using namespace std;
using namespace OpenXLSX;
const int MAXLL = 128;
// Report that writes its output to an .xlsx workbook.
// NOTE(review): this span comes from a diff view, so removed and added
// declarations appear side by side (two class heads, doc/wks next to
// workbook/worksheet, repeated member declarations) — confirm against the
// repository. The only code change made here is dropping the stray token that
// followed the second `private:` label, which could not compile.
class ReportExcel : public ReportBase{
class ReportExcel : public ReportBase {
public:
explicit ReportExcel(json data_) : ReportBase(data_) {createFile();};
virtual ~ReportExcel() {closeFile();};
private:
// The constructor creates the output file; the destructor closes it.
explicit ReportExcel(json data_) : ReportBase(data_) { createFile(); };
virtual ~ReportExcel() { closeFile(); };
protected:
// Cell writers: `style` is a key into `styles`; the default "" applies no format.
void writeString(int row, int col, const string& text, const string& style = "");
void writeInt(int row, int col, const int number, const string& style = "");
void writeDouble(int row, int col, const double number, const string& style = "");
// Formatting hooks called at the end of header(), body() and footer().
void formatHeader();
void formatBody();
void formatFooter();
// Fills `styles` with the named formats.
void createFormats();
private:
void createFile();
void closeFile();
XLDocument doc;
XLWorksheet wks;
void header() override;
void body() override;
void footer(double totalScore, int row);
void closeFile();
// Handles set in createFile(); `workbook` is passed to workbook_close() in closeFile().
lxw_workbook* workbook;
lxw_worksheet* worksheet;
// Style name -> format, filled by createFormats().
map<string, lxw_format*> styles;
void header() override;
void body() override;
void footer(double totalScore, int row);
};
};
#endif // !REPORTEXCEL_H