Merge pull request 'Exchange OpenXLSX to libxlsxwriter' (#8) from libxlsxwriter into main

Add multiple sheets to excel file
Add format and color to sheets
Add comparison with ZeroR
Add comparison with Best Results
Separate contextual menu from general in manage
This commit is contained in:
Ricardo Montañana Gómez 2023-09-20 11:17:16 +00:00
commit fc81730dfc
19 changed files with 481 additions and 184 deletions

1
.gitignore vendored
View File

@ -36,3 +36,4 @@ build/
cmake-build*/** cmake-build*/**
.idea .idea
puml/** puml/**
.vscode/settings.json

6
.gitmodules vendored
View File

@ -10,6 +10,6 @@
[submodule "lib/json"] [submodule "lib/json"]
path = lib/json path = lib/json
url = https://github.com/nlohmann/json.git url = https://github.com/nlohmann/json.git
[submodule "lib/openXLSX"] [submodule "lib/libxlsxwriter"]
path = lib/openXLSX path = lib/libxlsxwriter
url = https://github.com/troldal/OpenXLSX.git url = https://github.com/jmcnamara/libxlsxwriter.git

109
.vscode/settings.json vendored
View File

@ -1,109 +0,0 @@
{
"files.associations": {
"*.rmd": "markdown",
"*.py": "python",
"vector": "cpp",
"__bit_reference": "cpp",
"__bits": "cpp",
"__config": "cpp",
"__debug": "cpp",
"__errc": "cpp",
"__hash_table": "cpp",
"__locale": "cpp",
"__mutex_base": "cpp",
"__node_handle": "cpp",
"__nullptr": "cpp",
"__split_buffer": "cpp",
"__string": "cpp",
"__threading_support": "cpp",
"__tuple": "cpp",
"array": "cpp",
"atomic": "cpp",
"bitset": "cpp",
"cctype": "cpp",
"chrono": "cpp",
"clocale": "cpp",
"cmath": "cpp",
"compare": "cpp",
"complex": "cpp",
"concepts": "cpp",
"cstdarg": "cpp",
"cstddef": "cpp",
"cstdint": "cpp",
"cstdio": "cpp",
"cstdlib": "cpp",
"cstring": "cpp",
"ctime": "cpp",
"cwchar": "cpp",
"cwctype": "cpp",
"exception": "cpp",
"initializer_list": "cpp",
"ios": "cpp",
"iosfwd": "cpp",
"istream": "cpp",
"limits": "cpp",
"locale": "cpp",
"memory": "cpp",
"mutex": "cpp",
"new": "cpp",
"optional": "cpp",
"ostream": "cpp",
"ratio": "cpp",
"sstream": "cpp",
"stdexcept": "cpp",
"streambuf": "cpp",
"string": "cpp",
"string_view": "cpp",
"system_error": "cpp",
"tuple": "cpp",
"type_traits": "cpp",
"typeinfo": "cpp",
"unordered_map": "cpp",
"variant": "cpp",
"algorithm": "cpp",
"iostream": "cpp",
"iomanip": "cpp",
"numeric": "cpp",
"set": "cpp",
"__tree": "cpp",
"deque": "cpp",
"list": "cpp",
"map": "cpp",
"unordered_set": "cpp",
"any": "cpp",
"condition_variable": "cpp",
"forward_list": "cpp",
"fstream": "cpp",
"stack": "cpp",
"thread": "cpp",
"__memory": "cpp",
"filesystem": "cpp",
"*.toml": "toml",
"utility": "cpp",
"__verbose_abort": "cpp",
"bit": "cpp",
"random": "cpp",
"*.tcc": "cpp",
"functional": "cpp",
"iterator": "cpp",
"memory_resource": "cpp",
"format": "cpp",
"valarray": "cpp",
"regex": "cpp",
"span": "cpp",
"cfenv": "cpp",
"cinttypes": "cpp",
"csetjmp": "cpp",
"future": "cpp",
"queue": "cpp",
"typeindex": "cpp",
"shared_mutex": "cpp",
"*.ipp": "cpp",
"cassert": "cpp",
"charconv": "cpp",
"source_location": "cpp",
"ranges": "cpp"
},
"cmake.configureOnOpen": false,
"C_Cpp.default.configurationProvider": "ms-vscode.cmake-tools"
}

View File

@ -54,7 +54,6 @@ endif (ENABLE_CLANG_TIDY)
add_git_submodule("lib/mdlp") add_git_submodule("lib/mdlp")
add_git_submodule("lib/argparse") add_git_submodule("lib/argparse")
add_git_submodule("lib/json") add_git_submodule("lib/json")
add_git_submodule("lib/openXLSX")
# Subdirectories # Subdirectories
# -------------- # --------------

View File

@ -2,4 +2,36 @@
Bayesian Network Classifier with libtorch from scratch Bayesian Network Classifier with libtorch from scratch
## 0. Setup
### libxlsxwriter
Build and install libxlsxwriter before compiling BayesNet:
```bash
cd lib/libxlsxwriter
make
sudo make install
```
The library has to be installed in /usr/local/lib; otherwise, CMakeLists.txt has to be modified accordingly.
The following environment variable has to be set:
```bash
export LD_LIBRARY_PATH=/usr/local/lib
```
### Release
```bash
make release
```
### Debug & Tests
```bash
make debug
```
## 1. Introduction ## 1. Introduction

@ -1 +1 @@
Subproject commit 4acc51828f7f93f3b2058a63f54d112af4034503 Subproject commit 9c541ca72e7857dec71d8a41b97e42c2f1c92602

1
lib/libxlsxwriter Submodule

@ -0,0 +1 @@
Subproject commit 44e72c5862f9d549453a4ff6e8ceab0da19705e5

@ -1 +0,0 @@
Subproject commit b80da42d1454f361c29117095ebe1989437db390

View File

@ -5,12 +5,12 @@ include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp)
include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include) include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include)
include_directories(${BayesNet_SOURCE_DIR}/lib/json/include) include_directories(${BayesNet_SOURCE_DIR}/lib/json/include)
add_executable(main main.cc Folding.cc platformUtils.cc Experiment.cc Datasets.cc Models.cc ReportConsole.cc ReportBase.cc) add_executable(main main.cc Folding.cc platformUtils.cc Experiment.cc Datasets.cc Models.cc ReportConsole.cc ReportBase.cc)
add_executable(manage manage.cc Results.cc ReportConsole.cc ReportExcel.cc ReportBase.cc) add_executable(manage manage.cc Results.cc ReportConsole.cc ReportExcel.cc ReportBase.cc Datasets.cc platformUtils.cc)
add_executable(list list.cc platformUtils Datasets.cc) add_executable(list list.cc platformUtils Datasets.cc)
target_link_libraries(main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}") target_link_libraries(main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}")
if (${CMAKE_HOST_SYSTEM_NAME} MATCHES "Linux") if (${CMAKE_HOST_SYSTEM_NAME} MATCHES "Linux")
target_link_libraries(manage "${TORCH_LIBRARIES}" OpenXLSX::OpenXLSX stdc++fs) target_link_libraries(manage "${TORCH_LIBRARIES}" libxlsxwriter.so ArffFiles mdlp stdc++fs)
else() else()
target_link_libraries(manage "${TORCH_LIBRARIES}" OpenXLSX::OpenXLSX) target_link_libraries(manage "${TORCH_LIBRARIES}" libxlsxwriter.so ArffFiles mdlp)
endif() endif()
target_link_libraries(list ArffFiles mdlp "${TORCH_LIBRARIES}") target_link_libraries(list ArffFiles mdlp "${TORCH_LIBRARIES}")

View File

@ -1,10 +1,22 @@
#include <sstream> #include <sstream>
#include <locale> #include <locale>
#include "Datasets.h"
#include "ReportBase.h" #include "ReportBase.h"
#include "BestResult.h" #include "BestResult.h"
namespace platform { namespace platform {
/**
 * Builds the report base and the legend that explains every status symbol.
 *
 * @param data_   JSON document with the results to report.
 * @param compare stored in the `compare` member; compareResult() branches on it.
 *
 * The margin is fixed at 0.1 and is rendered as a percentage in the legend.
 * compareResult() can also produce the plain "=" status (score exactly on the
 * ZeroR + margin mark), so that status gets a legend entry as well; without it
 * the summary printers, which look the text up with meaning.at(), would throw
 * std::out_of_range.
 */
ReportBase::ReportBase(json data_, bool compare) : data(data_), compare(compare), margin(0.1)
{
    stringstream oss;
    oss << "ZeroR + " << setprecision(1) << fixed << margin * 100 << "%";
    const string threshold = oss.str();
    meaning = {
        {Symbols::equal_best, "Equal to best"},
        {Symbols::better_best, "Better than best"},
        {Symbols::cross, "Less than or equal to ZeroR"},
        {Symbols::upward_arrow, "Better than " + threshold},
        {"=", "Equal to " + threshold}
    };
}
string ReportBase::fromVector(const string& key) string ReportBase::fromVector(const string& key)
{ {
stringstream oss; stringstream oss;
@ -34,4 +46,62 @@ namespace platform {
header(); header();
body(); body();
} }
/**
 * Classifies a score and records the outcome in `summary`.
 *
 * When `compare` is set, the score is checked against the stored best result
 * of the model. Otherwise, only for the "accuracy" score and two-class
 * datasets, it is checked against the majority-class ratio increased by
 * `margin` (the mark is replaced by 0.9995 when it exceeds 1).
 *
 * @param dataset name of the dataset the score belongs to.
 * @param result  score obtained on that dataset.
 * @return the status symbol, or a single blank when no status applies.
 */
string ReportBase::compareResult(const string& dataset, double result)
{
    const string blank = " ";
    string status = blank;
    if (compare) {
        const double best = bestResult(dataset, data["model"].get<string>());
        if (result == best) {
            status = Symbols::equal_best;
        } else if (result > best) {
            status = Symbols::better_best;
        }
    } else if (data["score_name"].get<string>() == "accuracy") {
        auto dt = Datasets(Paths::datasets(), false);
        dt.loadDataset(dataset);
        if (dt.getNClasses(dataset) == 2) {
            vector<int> distribution = dt.getClassesCounts(dataset);
            double nSamples = dt.getNSamples(dataset);
            const int majority = *max_element(distribution.begin(), distribution.end());
            double mark = majority / nSamples * (1 + margin);
            if (mark > 1) {
                mark = 0.9995;
            }
            if (result < mark) {
                status = Symbols::cross;
            } else if (result > mark) {
                status = Symbols::upward_arrow;
            } else {
                status = "=";
            }
        }
    }
    if (status != blank) {
        // operator[] value-initialises a missing counter to 0 before the increment
        ++summary[status];
    }
    return status;
}
double ReportBase::bestResult(const string& dataset, const string& model)
{
double value = 0.0;
if (bestResults.size() == 0) {
// try to load the best results
string score = data["score_name"];
replace(score.begin(), score.end(), '_', '-');
string fileName = "best_results_" + score + "_" + model + ".json";
ifstream resultData(Paths::results() + "/" + fileName);
if (resultData.is_open()) {
bestResults = json::parse(resultData);
}
}
try {
value = bestResults.at(dataset).at(0);
}
catch (exception) {
value = 1.0;
}
return value;
}
} }

View File

@ -2,14 +2,26 @@
#define REPORTBASE_H #define REPORTBASE_H
#include <string> #include <string>
#include <iostream> #include <iostream>
#include "Paths.h"
#include <nlohmann/json.hpp> #include <nlohmann/json.hpp>
using json = nlohmann::json; using json = nlohmann::json;
namespace platform { namespace platform {
using namespace std; using namespace std;
// Status glyphs used by the reports, stored as narrow strings.
class Symbols {
public:
    inline static const string check_mark = "\u2714";
    inline static const string exclamation = "\u2757";
    inline static const string black_star = "\u2605";
    inline static const string cross = "\u2717";
    inline static const string upward_arrow = "\u27B6";
    inline static const string down_arrow = "\u27B4";
    // Aliases naming the role each glyph plays when comparing with the best results
    inline static const string equal_best = check_mark;
    inline static const string better_best = black_star;
};
class ReportBase { class ReportBase {
public: public:
explicit ReportBase(json data_) { data = data_; }; explicit ReportBase(json data_, bool compare);
virtual ~ReportBase() = default; virtual ~ReportBase() = default;
void show(); void show();
protected: protected:
@ -18,6 +30,15 @@ namespace platform {
string fVector(const string& title, const json& data, const int width, const int precision); string fVector(const string& title, const json& data, const int width, const int precision);
virtual void header() = 0; virtual void header() = 0;
virtual void body() = 0; virtual void body() = 0;
virtual void showSummary() = 0;
string compareResult(const string& dataset, double result);
map<string, int> summary;
double margin;
map<string, string> meaning;
private:
double bestResult(const string& dataset, const string& model);
bool compare;
json bestResults;
}; };
}; };
#endif #endif

View File

@ -11,11 +11,11 @@ namespace platform {
string do_grouping() const { return "\03"; } string do_grouping() const { return "\03"; }
}; };
string ReportConsole::headerLine(const string& text) string ReportConsole::headerLine(const string& text, int utf = 0)
{ {
int n = MAXL - text.length() - 3; int n = MAXL - text.length() - 3;
n = n < 0 ? 0 : n; n = n < 0 ? 0 : n;
return "* " + text + string(n, ' ') + "*\n"; return "* " + text + string(n + utf, ' ') + "*\n";
} }
void ReportConsole::header() void ReportConsole::header()
@ -36,8 +36,8 @@ namespace platform {
} }
void ReportConsole::body() void ReportConsole::body()
{ {
cout << Colors::GREEN() << " # Dataset Sampl. Feat. Cls Nodes Edges States Score Time Hyperparameters" << endl; cout << Colors::GREEN() << " # Dataset Sampl. Feat. Cls Nodes Edges States Score Time Hyperparameters" << endl;
cout << "=== ============================== ====== ===== === ========= ========= ========= =============== ================== ===============" << endl; cout << "=== ========================= ====== ===== === ========= ========= ========= =============== =================== ====================" << endl;
json lastResult; json lastResult;
double totalScore = 0.0; double totalScore = 0.0;
bool odd = true; bool odd = true;
@ -50,15 +50,17 @@ namespace platform {
auto color = odd ? Colors::CYAN() : Colors::BLUE(); auto color = odd ? Colors::CYAN() : Colors::BLUE();
cout << color; cout << color;
cout << setw(3) << index++ << " "; cout << setw(3) << index++ << " ";
cout << setw(30) << left << r["dataset"].get<string>() << " "; cout << setw(25) << left << r["dataset"].get<string>() << " ";
cout << setw(6) << right << r["samples"].get<int>() << " "; cout << setw(6) << right << r["samples"].get<int>() << " ";
cout << setw(5) << right << r["features"].get<int>() << " "; cout << setw(5) << right << r["features"].get<int>() << " ";
cout << setw(3) << right << r["classes"].get<int>() << " "; cout << setw(3) << right << r["classes"].get<int>() << " ";
cout << setw(9) << setprecision(2) << fixed << r["nodes"].get<float>() << " "; cout << setw(9) << setprecision(2) << fixed << r["nodes"].get<float>() << " ";
cout << setw(9) << setprecision(2) << fixed << r["leaves"].get<float>() << " "; cout << setw(9) << setprecision(2) << fixed << r["leaves"].get<float>() << " ";
cout << setw(9) << setprecision(2) << fixed << r["depth"].get<float>() << " "; cout << setw(9) << setprecision(2) << fixed << r["depth"].get<float>() << " ";
cout << setw(8) << right << setprecision(6) << fixed << r["score"].get<double>() << "±" << setw(6) << setprecision(4) << fixed << r["score_std"].get<double>() << " "; cout << setw(8) << right << setprecision(6) << fixed << r["score"].get<double>() << "±" << setw(6) << setprecision(4) << fixed << r["score_std"].get<double>();
cout << setw(11) << right << setprecision(6) << fixed << r["time"].get<double>() << "±" << setw(6) << setprecision(4) << fixed << r["time_std"].get<double>() << " "; const string status = compareResult(r["dataset"].get<string>(), r["score"].get<double>());
cout << status;
cout << setw(12) << right << setprecision(6) << fixed << r["time"].get<double>() << "±" << setw(6) << setprecision(4) << fixed << r["time_std"].get<double>() << " ";
try { try {
cout << r["hyperparameters"].get<string>(); cout << r["hyperparameters"].get<string>();
} }
@ -81,9 +83,21 @@ namespace platform {
footer(totalScore); footer(totalScore);
} }
} }
void ReportConsole::showSummary()
{
for (const auto& item : summary) {
stringstream oss;
oss << setw(3) << left << item.first;
oss << setw(3) << right << item.second << " ";
oss << left << meaning.at(item.first);
cout << headerLine(oss.str(), 2);
}
}
void ReportConsole::footer(double totalScore) void ReportConsole::footer(double totalScore)
{ {
cout << Colors::MAGENTA() << string(MAXL, '*') << endl; cout << Colors::MAGENTA() << string(MAXL, '*') << endl;
showSummary();
auto score = data["score_name"].get<string>(); auto score = data["score_name"].get<string>();
if (score == BestResult::scoreName()) { if (score == BestResult::scoreName()) {
stringstream oss; stringstream oss;

View File

@ -7,17 +7,18 @@
namespace platform { namespace platform {
using namespace std; using namespace std;
const int MAXL = 132; const int MAXL = 133;
class ReportConsole : public ReportBase { class ReportConsole : public ReportBase {
public: public:
explicit ReportConsole(json data_, int index = -1) : ReportBase(data_), selectedIndex(index) {}; explicit ReportConsole(json data_, bool compare = false, int index = -1) : ReportBase(data_, compare), selectedIndex(index) {};
virtual ~ReportConsole() = default; virtual ~ReportConsole() = default;
private: private:
int selectedIndex; int selectedIndex;
string headerLine(const string& text); string headerLine(const string& text, int utf);
void header() override; void header() override;
void body() override; void body() override;
void footer(double totalScore); void footer(double totalScore);
void showSummary();
}; };
}; };
#endif #endif

View File

@ -13,17 +13,195 @@ namespace platform {
string do_grouping() const { return "\03"; } string do_grouping() const { return "\03"; }
}; };
/**
 * Prepares an Excel report and immediately adds its sheet through createFile().
 *
 * @param data_    JSON document with the results to report.
 * @param compare  forwarded to ReportBase.
 * @param workbook workbook to add the sheet to; createFile() creates a new
 *                 one when this is NULL.
 */
ReportExcel::ReportExcel(json data_, bool compare, lxw_workbook* workbook) :
    ReportBase(data_, compare),
    workbook(workbook),
    row(0),
    normalSize(14), //font size for report body
    colorTitle(0xB1A0C7),
    colorOdd(0xDCE6F1),
    colorEven(0xFDE9D9)
{
    createFile();
}
/**
 * @return the workbook pointer currently held by this report.
 */
lxw_workbook* ReportExcel::getWorkbook()
{
    return this->workbook;
}
/**
 * Resolves a style name to the registered format for the current row parity.
 *
 * @param style base style name; an empty name means "no format".
 * @return NULL for an empty name, otherwise the "<style>_odd" or
 *         "<style>_even" format chosen from the `row` member.
 *         styles.at() throws std::out_of_range for an unregistered name.
 */
lxw_format* ReportExcel::efectiveStyle(const string& style)
{
    if (style.empty()) {
        return NULL;
    }
    const string parity = (row % 2 != 0) ? "_odd" : "_even";
    return styles.at(style + parity);
}
void ReportExcel::writeString(int row, int col, const string& text, const string& style)
{
worksheet_write_string(worksheet, row, col, text.c_str(), efectiveStyle(style));
}
void ReportExcel::writeInt(int row, int col, const int number, const string& style)
{
worksheet_write_number(worksheet, row, col, number, efectiveStyle(style));
}
void ReportExcel::writeDouble(int row, int col, const double number, const string& style)
{
worksheet_write_number(worksheet, row, col, number, efectiveStyle(style));
}
void ReportExcel::formatColumns()
{
worksheet_freeze_panes(worksheet, 6, 1);
vector<int> columns_sizes = { 22, 10, 9, 7, 12, 12, 12, 12, 12, 3, 15, 12, 23 };
for (int i = 0; i < columns_sizes.size(); ++i) {
worksheet_set_column(worksheet, i, i, columns_sizes.at(i), NULL);
}
}
/**
 * Sets the background colour of a format according to its parity flag.
 *
 * NOTE(review): `odd == true` selects colorEven and `odd == false` selects
 * colorOdd. This may be deliberate (0-based row index vs. 1-based Excel row
 * number) — confirm before changing.
 */
void ReportExcel::addColor(lxw_format* style, bool odd)
{
    if (odd) {
        format_set_bg_color(style, lxw_color_t(colorEven));
    } else {
        format_set_bg_color(style, lxw_color_t(colorOdd));
    }
}
/**
 * Configures `style` as the body style called `name`.
 *
 * @param name  one of "text", "textCentered", "bodyHeader", "result",
 *              "time", "ints" or "floats"; any other name only receives the
 *              background colour.
 * @param style format to configure.
 * @param odd   parity flag forwarded to addColor().
 */
void ReportExcel::createStyle(const string& name, lxw_format* style, bool odd)
{
    addColor(style, odd);
    // Attributes shared by every known body style
    const auto bodyCell = [&]() {
        format_set_font_size(style, normalSize);
        format_set_border(style, LXW_BORDER_THIN);
    };
    // Body cell that additionally carries a number format
    const auto numberCell = [&](const char* pattern) {
        bodyCell();
        format_set_num_format(style, pattern);
    };
    if (name == "text") {
        bodyCell();
        return;
    }
    if (name == "textCentered") {
        format_set_align(style, LXW_ALIGN_CENTER);
        bodyCell();
        return;
    }
    if (name == "bodyHeader") {
        format_set_bold(style);
        bodyCell();
        format_set_align(style, LXW_ALIGN_CENTER);
        format_set_align(style, LXW_ALIGN_VERTICAL_CENTER);
        // replaces the parity colour applied by addColor() above
        format_set_bg_color(style, lxw_color_t(colorTitle));
        return;
    }
    if (name == "result") {
        numberCell("0.0000000");
        return;
    }
    if (name == "time") {
        numberCell("#,##0.000000");
        return;
    }
    if (name == "ints") {
        numberCell("###,##0");
        return;
    }
    if (name == "floats") {
        numberCell("#,##0.00");
    }
}
void ReportExcel::createFormats()
{
auto styleNames = { "text", "textCentered", "bodyHeader", "result", "time", "ints", "floats" };
lxw_format* style;
for (string name : styleNames) {
lxw_format* style = workbook_add_format(workbook);
style = workbook_add_format(workbook);
createStyle(name, style, true);
styles[name + "_odd"] = style;
style = workbook_add_format(workbook);
createStyle(name, style, false);
styles[name + "_even"] = style;
}
// Header 1st line
lxw_format* headerFirst = workbook_add_format(workbook);
format_set_bold(headerFirst);
format_set_font_size(headerFirst, 18);
format_set_align(headerFirst, LXW_ALIGN_CENTER);
format_set_align(headerFirst, LXW_ALIGN_VERTICAL_CENTER);
format_set_border(headerFirst, LXW_BORDER_THIN);
format_set_bg_color(headerFirst, lxw_color_t(colorTitle));
// Header rest
lxw_format* headerRest = workbook_add_format(workbook);
format_set_bold(headerRest);
format_set_align(headerRest, LXW_ALIGN_CENTER);
format_set_font_size(headerRest, 16);
format_set_align(headerRest, LXW_ALIGN_VERTICAL_CENTER);
format_set_border(headerRest, LXW_BORDER_THIN);
format_set_bg_color(headerRest, lxw_color_t(colorOdd));
// Header small
lxw_format* headerSmall = workbook_add_format(workbook);
format_set_bold(headerSmall);
format_set_align(headerSmall, LXW_ALIGN_LEFT);
format_set_font_size(headerSmall, 12);
format_set_border(headerSmall, LXW_BORDER_THIN);
format_set_align(headerSmall, LXW_ALIGN_VERTICAL_CENTER);
format_set_bg_color(headerSmall, lxw_color_t(colorOdd));
// Summary style
lxw_format* summaryStyle = workbook_add_format(workbook);
format_set_bold(summaryStyle);
format_set_font_size(summaryStyle, 16);
format_set_border(summaryStyle, LXW_BORDER_THIN);
format_set_align(summaryStyle, LXW_ALIGN_VERTICAL_CENTER);
styles["headerFirst"] = headerFirst;
styles["headerRest"] = headerRest;
styles["headerSmall"] = headerSmall;
styles["summaryStyle"] = summaryStyle;
}
/**
 * Fills the document properties of the workbook; the title is taken from
 * the "title" entry of the report data.
 *
 * The title is kept in a std::string instead of a variable-length array,
 * which is not standard C++ and puts an unbounded buffer on the stack.
 * The non-const data() overload provides the writable, NUL-terminated
 * pointer that was previously obtained with strcpy.
 */
void ReportExcel::setProperties()
{
    string title = data["title"].get<string>();
    lxw_doc_properties properties = {
        .title = title.data(),
        .subject = "Machine learning results",
        .author = "Ricardo Montañana Gómez",
        .manager = "Dr. J. A. Gámez, Dr. J. M. Puerta",
        .company = "UCLM",
        .comments = "Created with libxlsxwriter and c++",
    };
    workbook_set_properties(workbook, &properties);
}
void ReportExcel::createFile() void ReportExcel::createFile()
{ {
doc.create(Paths::excel() + "some_results.xlsx"); if (workbook == NULL) {
wks = doc.workbook().worksheet("Sheet1"); workbook = workbook_new((Paths::excel() + fileName).c_str());
wks.setName(data["model"].get<string>()); }
const string name = data["model"].get<string>();
string suffix = "";
string efectiveName;
int num = 1;
// Create a sheet with the name of the model
while (true) {
efectiveName = name + suffix;
if (workbook_get_worksheet_by_name(workbook, efectiveName.c_str())) {
suffix = to_string(++num);
} else {
worksheet = workbook_add_worksheet(workbook, efectiveName.c_str());
break;
}
if (num > 100) {
throw invalid_argument("Couldn't create sheet " + efectiveName);
}
}
cout << "Adding sheet " << efectiveName << " to " << Paths::excel() + fileName << endl;
setProperties();
createFormats();
formatColumns();
} }
void ReportExcel::closeFile() void ReportExcel::closeFile()
{ {
doc.save(); workbook_close(workbook);
doc.close();
} }
void ReportExcel::header() void ReportExcel::header()
@ -32,45 +210,62 @@ namespace platform {
locale::global(mylocale); locale::global(mylocale);
cout.imbue(mylocale); cout.imbue(mylocale);
stringstream oss; stringstream oss;
wks.cell("A1").value().set( string message = data["model"].get<string>() + " ver. " + data["version"].get<string>() + " " +
"Report " + data["model"].get<string>() + " ver. " + data["version"].get<string>() + " with " + data["language"].get<string>() + " ver. " + data["language_version"].get<string>() +
to_string(data["folds"].get<int>()) + " Folds cross validation and " + to_string(data["seeds"].size()) + " with " + to_string(data["folds"].get<int>()) + " Folds cross validation and " + to_string(data["seeds"].size()) +
" random seeds. " + data["date"].get<string>() + " " + data["time"].get<string>()); " random seeds. " + data["date"].get<string>() + " " + data["time"].get<string>();
wks.cell("A2").value() = data["title"].get<string>(); worksheet_merge_range(worksheet, 0, 0, 0, 12, message.c_str(), styles["headerFirst"]);
wks.cell("A3").value() = "Random seeds: " + fromVector("seeds") + " Stratified: " + worksheet_merge_range(worksheet, 1, 0, 1, 12, data["title"].get<string>().c_str(), styles["headerRest"]);
(data["stratified"].get<bool>() ? "True" : "False"); worksheet_merge_range(worksheet, 2, 0, 3, 0, ("Score is " + data["score_name"].get<string>()).c_str(), styles["headerRest"]);
oss << "Execution took " << setprecision(2) << fixed << data["duration"].get<float>() << " seconds, " worksheet_merge_range(worksheet, 2, 1, 3, 3, "Execution time", styles["headerRest"]);
<< data["duration"].get<float>() / 3600 << " hours, on " << data["platform"].get<string>(); oss << setprecision(2) << fixed << data["duration"].get<float>() << " s";
wks.cell("A4").value() = oss.str(); worksheet_merge_range(worksheet, 2, 4, 2, 5, oss.str().c_str(), styles["headerRest"]);
wks.cell("A5").value() = "Score is " + data["score_name"].get<string>(); oss.str("");
oss.clear();
oss << setprecision(2) << fixed << data["duration"].get<float>() / 3600 << " h";
worksheet_merge_range(worksheet, 3, 4, 3, 5, oss.str().c_str(), styles["headerRest"]);
worksheet_merge_range(worksheet, 2, 6, 3, 7, "Platform", styles["headerRest"]);
worksheet_merge_range(worksheet, 2, 8, 3, 9, data["platform"].get<string>().c_str(), styles["headerRest"]);
worksheet_merge_range(worksheet, 2, 10, 2, 12, ("Random seeds: " + fromVector("seeds")).c_str(), styles["headerSmall"]);
oss.str("");
oss.clear();
oss << "Stratified: " << (data["stratified"].get<bool>() ? "True" : "False");
worksheet_merge_range(worksheet, 3, 10, 3, 11, oss.str().c_str(), styles["headerSmall"]);
oss.str("");
oss.clear();
oss << "Discretized: " << (data["discretized"].get<bool>() ? "True" : "False");
worksheet_write_string(worksheet, 3, 12, oss.str().c_str(), styles["headerSmall"]);
} }
void ReportExcel::body() void ReportExcel::body()
{ {
auto head = vector<string>( auto head = vector<string>(
{ "Dataset", "Samples", "Features", "Classes", "Nodes", "Edges", "States", "Score", "Score Std.", "Time", { "Dataset", "Samples", "Features", "Classes", "Nodes", "Edges", "States", "Score", "Score Std.", "St.", "Time",
"Time Std.", "Hyperparameters" }); "Time Std.", "Hyperparameters" });
int col = 1; int col = 0;
for (const auto& item : head) { for (const auto& item : head) {
wks.cell(8, col++).value() = item; writeString(5, col++, item, "bodyHeader");
} }
int row = 9; row = 6;
col = 1; col = 0;
int hypSize = 22;
json lastResult; json lastResult;
double totalScore = 0.0; double totalScore = 0.0;
string hyperparameters; string hyperparameters;
for (const auto& r : data["results"]) { for (const auto& r : data["results"]) {
wks.cell(row, col).value() = r["dataset"].get<string>(); writeString(row, col, r["dataset"].get<string>(), "text");
wks.cell(row, col + 1).value() = r["samples"].get<int>(); writeInt(row, col + 1, r["samples"].get<int>(), "ints");
wks.cell(row, col + 2).value() = r["features"].get<int>(); writeInt(row, col + 2, r["features"].get<int>(), "ints");
wks.cell(row, col + 3).value() = r["classes"].get<int>(); writeInt(row, col + 3, r["classes"].get<int>(), "ints");
wks.cell(row, col + 4).value() = r["nodes"].get<float>(); writeDouble(row, col + 4, r["nodes"].get<float>(), "floats");
wks.cell(row, col + 5).value() = r["leaves"].get<float>(); writeDouble(row, col + 5, r["leaves"].get<float>(), "floats");
wks.cell(row, col + 6).value() = r["depth"].get<float>(); writeDouble(row, col + 6, r["depth"].get<double>(), "floats");
wks.cell(row, col + 7).value() = r["score"].get<double>(); writeDouble(row, col + 7, r["score"].get<double>(), "result");
wks.cell(row, col + 8).value() = r["score_std"].get<double>(); writeDouble(row, col + 8, r["score_std"].get<double>(), "result");
wks.cell(row, col + 9).value() = r["time"].get<double>(); const string status = compareResult(r["dataset"].get<string>(), r["score"].get<double>());
wks.cell(row, col + 10).value() = r["time_std"].get<double>(); writeString(row, col + 9, status, "textCentered");
writeDouble(row, col + 10, r["time"].get<double>(), "time");
writeDouble(row, col + 11, r["time_std"].get<double>(), "time");
try { try {
hyperparameters = r["hyperparameters"].get<string>(); hyperparameters = r["hyperparameters"].get<string>();
} }
@ -79,31 +274,57 @@ namespace platform {
oss << r["hyperparameters"]; oss << r["hyperparameters"];
hyperparameters = oss.str(); hyperparameters = oss.str();
} }
wks.cell(row, col + 11).value() = hyperparameters; if (hyperparameters.size() > hypSize) {
hypSize = hyperparameters.size();
}
writeString(row, col + 12, hyperparameters, "text");
lastResult = r; lastResult = r;
totalScore += r["score"].get<double>(); totalScore += r["score"].get<double>();
row++; row++;
} }
// Set the right column width of hyperparameters with the maximum length
worksheet_set_column(worksheet, 12, 12, hypSize + 5, NULL);
// Show totals if only one dataset is present in the result
if (data["results"].size() == 1) { if (data["results"].size() == 1) {
for (const string& group : { "scores_train", "scores_test", "times_train", "times_test" }) { for (const string& group : { "scores_train", "scores_test", "times_train", "times_test" }) {
row++; row++;
col = 1; col = 1;
wks.cell(row, col).value() = group; writeString(row, col, group, "text");
for (double item : lastResult[group]) { for (double item : lastResult[group]) {
wks.cell(row, ++col).value() = item; string style = group.find("scores") != string::npos ? "result" : "time";
writeDouble(row, ++col, item, style);
} }
} }
// Set with of columns to show those totals completely
worksheet_set_column(worksheet, 1, 1, 12, NULL);
for (int i = 2; i < 7; ++i) {
// doesn't work with from col to col, so...
worksheet_set_column(worksheet, i, i, 15, NULL);
}
} else { } else {
footer(totalScore, row); footer(totalScore, row);
} }
} }
void ReportExcel::showSummary()
{
for (const auto& item : summary) {
worksheet_write_string(worksheet, row + 2, 1, item.first.c_str(), styles["summaryStyle"]);
worksheet_write_number(worksheet, row + 2, 2, item.second, styles["summaryStyle"]);
worksheet_merge_range(worksheet, row + 2, 3, row + 2, 5, meaning.at(item.first).c_str(), styles["summaryStyle"]);
row += 1;
}
}
void ReportExcel::footer(double totalScore, int row) void ReportExcel::footer(double totalScore, int row)
{ {
showSummary();
row += 4 + summary.size();
auto score = data["score_name"].get<string>(); auto score = data["score_name"].get<string>();
if (score == BestResult::scoreName()) { if (score == BestResult::scoreName()) {
wks.cell(row + 2, 1).value() = score + " compared to " + BestResult::title() + " .: "; worksheet_merge_range(worksheet, row, 1, row, 5, (score + " compared to " + BestResult::title() + " .:").c_str(), efectiveStyle("text"));
wks.cell(row + 2, 5).value() = totalScore / BestResult::score(); writeDouble(row, 6, totalScore / BestResult::score(), "result");
} }
} }
} }

View File

@ -1,25 +1,42 @@
#ifndef REPORTEXCEL_H #ifndef REPORTEXCEL_H
#define REPORTEXCEL_H #define REPORTEXCEL_H
#include <OpenXLSX.hpp> #include<map>
#include "xlsxwriter.h"
#include "ReportBase.h" #include "ReportBase.h"
#include "Paths.h"
#include "Colors.h" #include "Colors.h"
namespace platform { namespace platform {
using namespace std; using namespace std;
using namespace OpenXLSX;
const int MAXLL = 128; const int MAXLL = 128;
class ReportExcel : public ReportBase{
class ReportExcel : public ReportBase {
public: public:
explicit ReportExcel(json data_) : ReportBase(data_) {createFile();}; explicit ReportExcel(json data_, bool compare, lxw_workbook* workbook);
virtual ~ReportExcel() {closeFile();}; lxw_workbook* getWorkbook();
private: private:
void writeString(int row, int col, const string& text, const string& style = "");
void writeInt(int row, int col, const int number, const string& style = "");
void writeDouble(int row, int col, const double number, const string& style = "");
void formatColumns();
void createFormats();
void setProperties();
void createFile(); void createFile();
void closeFile(); void closeFile();
XLDocument doc; void showSummary();
XLWorksheet wks; lxw_workbook* workbook;
lxw_worksheet* worksheet;
map<string, lxw_format*> styles;
int row;
int normalSize; //font size for report body
uint32_t colorTitle;
uint32_t colorOdd;
uint32_t colorEven;
const string fileName = "some_results.xlsx";
void header() override; void header() override;
void body() override; void body() override;
void footer(double totalScore, int row); void footer(double totalScore, int row);
void createStyle(const string& name, lxw_format* style, bool odd);
void addColor(lxw_format* style, bool odd);
lxw_format* efectiveStyle(const string& name);
}; };
}; };
#endif // !REPORTEXCEL_H #endif // !REPORTEXCEL_H

View File

@ -104,15 +104,17 @@ namespace platform {
cout << "Invalid index" << endl; cout << "Invalid index" << endl;
return -1; return -1;
} }
void Results::report(const int index, const bool excelReport) const void Results::report(const int index, const bool excelReport)
{ {
cout << Colors::YELLOW() << "Reporting " << files.at(index).getFilename() << endl; cout << Colors::YELLOW() << "Reporting " << files.at(index).getFilename() << endl;
auto data = files.at(index).load(); auto data = files.at(index).load();
if (excelReport) { if (excelReport) {
ReportExcel reporter(data); ReportExcel reporter(data, compare, workbook);
reporter.show(); reporter.show();
openExcel = true;
workbook = reporter.getWorkbook();
} else { } else {
ReportConsole reporter(data); ReportConsole reporter(data, compare);
reporter.show(); reporter.show();
} }
} }
@ -124,7 +126,7 @@ namespace platform {
return; return;
} }
cout << Colors::YELLOW() << "Showing " << files.at(index).getFilename() << endl; cout << Colors::YELLOW() << "Showing " << files.at(index).getFilename() << endl;
ReportConsole reporter(data, idx); ReportConsole reporter(data, compare, idx);
reporter.show(); reporter.show();
} }
void Results::menu() void Results::menu()
@ -132,9 +134,21 @@ namespace platform {
char option; char option;
int index; int index;
bool finished = false; bool finished = false;
string color, context;
string filename, line, options = "qldhsre"; string filename, line, options = "qldhsre";
while (!finished) { while (!finished) {
cout << Colors::RESET() << "Choose option (quit='q', list='l', delete='d', hide='h', sort='s', report='r', excel='e'): "; if (indexList) {
color = Colors::GREEN();
context = " (quit='q', list='l', delete='d', hide='h', sort='s', report='r', excel='e'): ";
options = "qldhsre";
} else {
color = Colors::MAGENTA();
context = " (quit='q', list='l'): ";
options = "ql";
}
cout << Colors::RESET() << color;
cout << "Choose option " << context;
getline(cin, line); getline(cin, line);
if (line.size() == 0) if (line.size() == 0)
continue; continue;
@ -148,6 +162,7 @@ namespace platform {
if (all_of(line.begin(), line.end(), ::isdigit)) { if (all_of(line.begin(), line.end(), ::isdigit)) {
int idx = stoi(line); int idx = stoi(line);
if (indexList) { if (indexList) {
// The value is about the files list
index = idx; index = idx;
if (index >= 0 && index < files.size()) { if (index >= 0 && index < files.size()) {
report(index, false); report(index, false);
@ -155,6 +170,7 @@ namespace platform {
continue; continue;
} }
} else { } else {
// The value is about the result shown on screen
showIndex(index, idx); showIndex(index, idx);
continue; continue;
} }
@ -281,6 +297,9 @@ namespace platform {
sortDate(); sortDate();
show(); show();
menu(); menu();
if (openExcel) {
workbook_close(workbook);
}
cout << "Done!" << endl; cout << "Done!" << endl;
} }

View File

@ -1,5 +1,6 @@
#ifndef RESULTS_H #ifndef RESULTS_H
#define RESULTS_H #define RESULTS_H
#include "xlsxwriter.h"
#include <map> #include <map>
#include <vector> #include <vector>
#include <string> #include <string>
@ -34,7 +35,11 @@ namespace platform {
}; };
class Results { class Results {
public: public:
Results(const string& path, const int max, const string& model, const string& score, bool complete, bool partial) : path(path), max(max), model(model), scoreName(score), complete(complete), partial(partial) { load(); }; Results(const string& path, const int max, const string& model, const string& score, bool complete, bool partial, bool compare) :
path(path), max(max), model(model), scoreName(score), complete(complete), partial(partial), compare(compare)
{
load();
};
void manage(); void manage();
private: private:
string path; string path;
@ -44,10 +49,13 @@ namespace platform {
bool complete; bool complete;
bool partial; bool partial;
bool indexList = true; bool indexList = true;
bool openExcel = false;
bool compare;
lxw_workbook* workbook = NULL;
vector<Result> files; vector<Result> files;
void load(); // Loads the list of results void load(); // Loads the list of results
void show() const; void show() const;
void report(const int index, const bool excelReport) const; void report(const int index, const bool excelReport);
void showIndex(const int index, const int idx) const; void showIndex(const int index, const int idx) const;
int getIndex(const string& intent) const; int getIndex(const string& intent) const;
void menu(); void menu();

View File

@ -87,7 +87,7 @@ int main(int argc, char** argv)
auto stratified = program.get<bool>("stratified"); auto stratified = program.get<bool>("stratified");
auto n_folds = program.get<int>("folds"); auto n_folds = program.get<int>("folds");
auto seeds = program.get<vector<int>>("seeds"); auto seeds = program.get<vector<int>>("seeds");
auto hyperparameters =program.get<string>("hyperparameters"); auto hyperparameters = program.get<string>("hyperparameters");
vector<string> filesToTest; vector<string> filesToTest;
auto datasets = platform::Datasets(path, true, platform::ARFF); auto datasets = platform::Datasets(path, true, platform::ARFF);
auto title = program.get<string>("title"); auto title = program.get<string>("title");
@ -102,7 +102,7 @@ int main(int argc, char** argv)
} }
filesToTest.push_back(file_name); filesToTest.push_back(file_name);
} else { } else {
filesToTest = platform::Datasets(path, true, platform::ARFF).getNames(); filesToTest = datasets.getNames();
saveResults = true; saveResults = true;
} }
/* /*

View File

@ -14,6 +14,7 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
program.add_argument("-s", "--score").default_value("any").help("Filter results of the score name supplied"); program.add_argument("-s", "--score").default_value("any").help("Filter results of the score name supplied");
program.add_argument("--complete").help("Show only results with all datasets").default_value(false).implicit_value(true); program.add_argument("--complete").help("Show only results with all datasets").default_value(false).implicit_value(true);
program.add_argument("--partial").help("Show only partial results").default_value(false).implicit_value(true); program.add_argument("--partial").help("Show only partial results").default_value(false).implicit_value(true);
program.add_argument("--compare").help("Compare with best results").default_value(false).implicit_value(true);
try { try {
program.parse_args(argc, argv); program.parse_args(argc, argv);
auto number = program.get<int>("number"); auto number = program.get<int>("number");
@ -24,6 +25,7 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
auto score = program.get<string>("score"); auto score = program.get<string>("score");
auto complete = program.get<bool>("complete"); auto complete = program.get<bool>("complete");
auto partial = program.get<bool>("partial"); auto partial = program.get<bool>("partial");
auto compare = program.get<bool>("compare");
} }
catch (const exception& err) { catch (const exception& err) {
cerr << err.what() << endl; cerr << err.what() << endl;
@ -41,9 +43,10 @@ int main(int argc, char** argv)
auto score = program.get<string>("score"); auto score = program.get<string>("score");
auto complete = program.get<bool>("complete"); auto complete = program.get<bool>("complete");
auto partial = program.get<bool>("partial"); auto partial = program.get<bool>("partial");
auto compare = program.get<bool>("compare");
if (complete) if (complete)
partial = false; partial = false;
auto results = platform::Results(platform::Paths::results(), number, model, score, complete, partial); auto results = platform::Results(platform::Paths::results(), number, model, score, complete, partial, compare);
results.manage(); results.manage();
return 0; return 0;
} }