Merge pull request 'Exchange OpenXLSX to libxlsxwriter' (#8) from libxlsxwriter into main
Add multiple sheets to excel file Add format and color to sheets Add comparison with ZeroR Add comparison with Best Results Separate contextual menu from general in manage
This commit is contained in:
commit
fc81730dfc
1
.gitignore
vendored
1
.gitignore
vendored
@ -36,3 +36,4 @@ build/
|
||||
cmake-build*/**
|
||||
.idea
|
||||
puml/**
|
||||
.vscode/settings.json
|
||||
|
6
.gitmodules
vendored
6
.gitmodules
vendored
@ -10,6 +10,6 @@
|
||||
[submodule "lib/json"]
|
||||
path = lib/json
|
||||
url = https://github.com/nlohmann/json.git
|
||||
[submodule "lib/openXLSX"]
|
||||
path = lib/openXLSX
|
||||
url = https://github.com/troldal/OpenXLSX.git
|
||||
[submodule "lib/libxlsxwriter"]
|
||||
path = lib/libxlsxwriter
|
||||
url = https://github.com/jmcnamara/libxlsxwriter.git
|
||||
|
109
.vscode/settings.json
vendored
109
.vscode/settings.json
vendored
@ -1,109 +0,0 @@
|
||||
{
|
||||
"files.associations": {
|
||||
"*.rmd": "markdown",
|
||||
"*.py": "python",
|
||||
"vector": "cpp",
|
||||
"__bit_reference": "cpp",
|
||||
"__bits": "cpp",
|
||||
"__config": "cpp",
|
||||
"__debug": "cpp",
|
||||
"__errc": "cpp",
|
||||
"__hash_table": "cpp",
|
||||
"__locale": "cpp",
|
||||
"__mutex_base": "cpp",
|
||||
"__node_handle": "cpp",
|
||||
"__nullptr": "cpp",
|
||||
"__split_buffer": "cpp",
|
||||
"__string": "cpp",
|
||||
"__threading_support": "cpp",
|
||||
"__tuple": "cpp",
|
||||
"array": "cpp",
|
||||
"atomic": "cpp",
|
||||
"bitset": "cpp",
|
||||
"cctype": "cpp",
|
||||
"chrono": "cpp",
|
||||
"clocale": "cpp",
|
||||
"cmath": "cpp",
|
||||
"compare": "cpp",
|
||||
"complex": "cpp",
|
||||
"concepts": "cpp",
|
||||
"cstdarg": "cpp",
|
||||
"cstddef": "cpp",
|
||||
"cstdint": "cpp",
|
||||
"cstdio": "cpp",
|
||||
"cstdlib": "cpp",
|
||||
"cstring": "cpp",
|
||||
"ctime": "cpp",
|
||||
"cwchar": "cpp",
|
||||
"cwctype": "cpp",
|
||||
"exception": "cpp",
|
||||
"initializer_list": "cpp",
|
||||
"ios": "cpp",
|
||||
"iosfwd": "cpp",
|
||||
"istream": "cpp",
|
||||
"limits": "cpp",
|
||||
"locale": "cpp",
|
||||
"memory": "cpp",
|
||||
"mutex": "cpp",
|
||||
"new": "cpp",
|
||||
"optional": "cpp",
|
||||
"ostream": "cpp",
|
||||
"ratio": "cpp",
|
||||
"sstream": "cpp",
|
||||
"stdexcept": "cpp",
|
||||
"streambuf": "cpp",
|
||||
"string": "cpp",
|
||||
"string_view": "cpp",
|
||||
"system_error": "cpp",
|
||||
"tuple": "cpp",
|
||||
"type_traits": "cpp",
|
||||
"typeinfo": "cpp",
|
||||
"unordered_map": "cpp",
|
||||
"variant": "cpp",
|
||||
"algorithm": "cpp",
|
||||
"iostream": "cpp",
|
||||
"iomanip": "cpp",
|
||||
"numeric": "cpp",
|
||||
"set": "cpp",
|
||||
"__tree": "cpp",
|
||||
"deque": "cpp",
|
||||
"list": "cpp",
|
||||
"map": "cpp",
|
||||
"unordered_set": "cpp",
|
||||
"any": "cpp",
|
||||
"condition_variable": "cpp",
|
||||
"forward_list": "cpp",
|
||||
"fstream": "cpp",
|
||||
"stack": "cpp",
|
||||
"thread": "cpp",
|
||||
"__memory": "cpp",
|
||||
"filesystem": "cpp",
|
||||
"*.toml": "toml",
|
||||
"utility": "cpp",
|
||||
"__verbose_abort": "cpp",
|
||||
"bit": "cpp",
|
||||
"random": "cpp",
|
||||
"*.tcc": "cpp",
|
||||
"functional": "cpp",
|
||||
"iterator": "cpp",
|
||||
"memory_resource": "cpp",
|
||||
"format": "cpp",
|
||||
"valarray": "cpp",
|
||||
"regex": "cpp",
|
||||
"span": "cpp",
|
||||
"cfenv": "cpp",
|
||||
"cinttypes": "cpp",
|
||||
"csetjmp": "cpp",
|
||||
"future": "cpp",
|
||||
"queue": "cpp",
|
||||
"typeindex": "cpp",
|
||||
"shared_mutex": "cpp",
|
||||
"*.ipp": "cpp",
|
||||
"cassert": "cpp",
|
||||
"charconv": "cpp",
|
||||
"source_location": "cpp",
|
||||
"ranges": "cpp"
|
||||
},
|
||||
"cmake.configureOnOpen": false,
|
||||
"C_Cpp.default.configurationProvider": "ms-vscode.cmake-tools"
|
||||
}
|
@ -54,7 +54,6 @@ endif (ENABLE_CLANG_TIDY)
|
||||
add_git_submodule("lib/mdlp")
|
||||
add_git_submodule("lib/argparse")
|
||||
add_git_submodule("lib/json")
|
||||
add_git_submodule("lib/openXLSX")
|
||||
|
||||
# Subdirectories
|
||||
# --------------
|
||||
|
32
README.md
32
README.md
@ -2,4 +2,36 @@
|
||||
|
||||
Bayesian Network Classifier with libtorch from scratch
|
||||
|
||||
## 0. Setup
|
||||
|
||||
### libxlswriter
|
||||
|
||||
Before compiling BayesNet.
|
||||
|
||||
```bash
|
||||
cd lib/libxlsxwriter
|
||||
make
|
||||
sudo make install
|
||||
```
|
||||
|
||||
It has to be installed in /usr/local/lib otherwise CMakeLists.txt has to be modified accordingly
|
||||
|
||||
Environment variable has to be set:
|
||||
|
||||
```bash
|
||||
export LD_LIBRARY_PATH=/usr/local/lib
|
||||
```
|
||||
|
||||
### Release
|
||||
|
||||
```bash
|
||||
make release
|
||||
```
|
||||
|
||||
### Debug & Tests
|
||||
|
||||
```bash
|
||||
make debug
|
||||
```
|
||||
|
||||
## 1. Introduction
|
||||
|
@ -1 +1 @@
|
||||
Subproject commit 4acc51828f7f93f3b2058a63f54d112af4034503
|
||||
Subproject commit 9c541ca72e7857dec71d8a41b97e42c2f1c92602
|
1
lib/libxlsxwriter
Submodule
1
lib/libxlsxwriter
Submodule
@ -0,0 +1 @@
|
||||
Subproject commit 44e72c5862f9d549453a4ff6e8ceab0da19705e5
|
@ -1 +0,0 @@
|
||||
Subproject commit b80da42d1454f361c29117095ebe1989437db390
|
@ -5,12 +5,12 @@ include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp)
|
||||
include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include)
|
||||
include_directories(${BayesNet_SOURCE_DIR}/lib/json/include)
|
||||
add_executable(main main.cc Folding.cc platformUtils.cc Experiment.cc Datasets.cc Models.cc ReportConsole.cc ReportBase.cc)
|
||||
add_executable(manage manage.cc Results.cc ReportConsole.cc ReportExcel.cc ReportBase.cc)
|
||||
add_executable(manage manage.cc Results.cc ReportConsole.cc ReportExcel.cc ReportBase.cc Datasets.cc platformUtils.cc)
|
||||
add_executable(list list.cc platformUtils Datasets.cc)
|
||||
target_link_libraries(main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}")
|
||||
if (${CMAKE_HOST_SYSTEM_NAME} MATCHES "Linux")
|
||||
target_link_libraries(manage "${TORCH_LIBRARIES}" OpenXLSX::OpenXLSX stdc++fs)
|
||||
target_link_libraries(manage "${TORCH_LIBRARIES}" libxlsxwriter.so ArffFiles mdlp stdc++fs)
|
||||
else()
|
||||
target_link_libraries(manage "${TORCH_LIBRARIES}" OpenXLSX::OpenXLSX)
|
||||
target_link_libraries(manage "${TORCH_LIBRARIES}" libxlsxwriter.so ArffFiles mdlp)
|
||||
endif()
|
||||
target_link_libraries(list ArffFiles mdlp "${TORCH_LIBRARIES}")
|
@ -1,10 +1,22 @@
|
||||
#include <sstream>
|
||||
#include <locale>
|
||||
#include "Datasets.h"
|
||||
#include "ReportBase.h"
|
||||
#include "BestResult.h"
|
||||
|
||||
|
||||
namespace platform {
|
||||
ReportBase::ReportBase(json data_, bool compare) : data(data_), compare(compare), margin(0.1)
|
||||
{
|
||||
stringstream oss;
|
||||
oss << "Better than ZeroR + " << setprecision(1) << fixed << margin * 100 << "%";
|
||||
meaning = {
|
||||
{Symbols::equal_best, "Equal to best"},
|
||||
{Symbols::better_best, "Better than best"},
|
||||
{Symbols::cross, "Less than or equal to ZeroR"},
|
||||
{Symbols::upward_arrow, oss.str()}
|
||||
};
|
||||
}
|
||||
string ReportBase::fromVector(const string& key)
|
||||
{
|
||||
stringstream oss;
|
||||
@ -34,4 +46,62 @@ namespace platform {
|
||||
header();
|
||||
body();
|
||||
}
|
||||
string ReportBase::compareResult(const string& dataset, double result)
|
||||
{
|
||||
string status = " ";
|
||||
if (compare) {
|
||||
double best = bestResult(dataset, data["model"].get<string>());
|
||||
if (result == best) {
|
||||
status = Symbols::equal_best;
|
||||
} else if (result > best) {
|
||||
status = Symbols::better_best;
|
||||
}
|
||||
} else {
|
||||
if (data["score_name"].get<string>() == "accuracy") {
|
||||
auto dt = Datasets(Paths::datasets(), false);
|
||||
dt.loadDataset(dataset);
|
||||
auto numClasses = dt.getNClasses(dataset);
|
||||
if (numClasses == 2) {
|
||||
vector<int> distribution = dt.getClassesCounts(dataset);
|
||||
double nSamples = dt.getNSamples(dataset);
|
||||
vector<int>::iterator maxValue = max_element(distribution.begin(), distribution.end());
|
||||
double mark = *maxValue / nSamples * (1 + margin);
|
||||
if (mark > 1) {
|
||||
mark = 0.9995;
|
||||
}
|
||||
status = result < mark ? Symbols::cross : result > mark ? Symbols::upward_arrow : "=";
|
||||
}
|
||||
}
|
||||
}
|
||||
if (status != " ") {
|
||||
auto item = summary.find(status);
|
||||
if (item != summary.end()) {
|
||||
summary[status]++;
|
||||
} else {
|
||||
summary[status] = 1;
|
||||
}
|
||||
}
|
||||
return status;
|
||||
}
|
||||
double ReportBase::bestResult(const string& dataset, const string& model)
|
||||
{
|
||||
double value = 0.0;
|
||||
if (bestResults.size() == 0) {
|
||||
// try to load the best results
|
||||
string score = data["score_name"];
|
||||
replace(score.begin(), score.end(), '_', '-');
|
||||
string fileName = "best_results_" + score + "_" + model + ".json";
|
||||
ifstream resultData(Paths::results() + "/" + fileName);
|
||||
if (resultData.is_open()) {
|
||||
bestResults = json::parse(resultData);
|
||||
}
|
||||
}
|
||||
try {
|
||||
value = bestResults.at(dataset).at(0);
|
||||
}
|
||||
catch (exception) {
|
||||
value = 1.0;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
}
|
@ -2,14 +2,26 @@
|
||||
#define REPORTBASE_H
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include "Paths.h"
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
using json = nlohmann::json;
|
||||
namespace platform {
|
||||
using namespace std;
|
||||
class Symbols {
|
||||
public:
|
||||
inline static const string check_mark{ "\u2714" };
|
||||
inline static const string exclamation{ "\u2757" };
|
||||
inline static const string black_star{ "\u2605" };
|
||||
inline static const string cross{ "\u2717" };
|
||||
inline static const string upward_arrow{ "\u27B6" };
|
||||
inline static const string down_arrow{ "\u27B4" };
|
||||
inline static const string equal_best{ check_mark };
|
||||
inline static const string better_best{ black_star };
|
||||
};
|
||||
class ReportBase {
|
||||
public:
|
||||
explicit ReportBase(json data_) { data = data_; };
|
||||
explicit ReportBase(json data_, bool compare);
|
||||
virtual ~ReportBase() = default;
|
||||
void show();
|
||||
protected:
|
||||
@ -18,6 +30,15 @@ namespace platform {
|
||||
string fVector(const string& title, const json& data, const int width, const int precision);
|
||||
virtual void header() = 0;
|
||||
virtual void body() = 0;
|
||||
virtual void showSummary() = 0;
|
||||
string compareResult(const string& dataset, double result);
|
||||
map<string, int> summary;
|
||||
double margin;
|
||||
map<string, string> meaning;
|
||||
private:
|
||||
double bestResult(const string& dataset, const string& model);
|
||||
bool compare;
|
||||
json bestResults;
|
||||
};
|
||||
};
|
||||
#endif
|
@ -11,11 +11,11 @@ namespace platform {
|
||||
string do_grouping() const { return "\03"; }
|
||||
};
|
||||
|
||||
string ReportConsole::headerLine(const string& text)
|
||||
string ReportConsole::headerLine(const string& text, int utf = 0)
|
||||
{
|
||||
int n = MAXL - text.length() - 3;
|
||||
n = n < 0 ? 0 : n;
|
||||
return "* " + text + string(n, ' ') + "*\n";
|
||||
return "* " + text + string(n + utf, ' ') + "*\n";
|
||||
}
|
||||
|
||||
void ReportConsole::header()
|
||||
@ -36,8 +36,8 @@ namespace platform {
|
||||
}
|
||||
void ReportConsole::body()
|
||||
{
|
||||
cout << Colors::GREEN() << " # Dataset Sampl. Feat. Cls Nodes Edges States Score Time Hyperparameters" << endl;
|
||||
cout << "=== ============================== ====== ===== === ========= ========= ========= =============== ================== ===============" << endl;
|
||||
cout << Colors::GREEN() << " # Dataset Sampl. Feat. Cls Nodes Edges States Score Time Hyperparameters" << endl;
|
||||
cout << "=== ========================= ====== ===== === ========= ========= ========= =============== =================== ====================" << endl;
|
||||
json lastResult;
|
||||
double totalScore = 0.0;
|
||||
bool odd = true;
|
||||
@ -50,15 +50,17 @@ namespace platform {
|
||||
auto color = odd ? Colors::CYAN() : Colors::BLUE();
|
||||
cout << color;
|
||||
cout << setw(3) << index++ << " ";
|
||||
cout << setw(30) << left << r["dataset"].get<string>() << " ";
|
||||
cout << setw(25) << left << r["dataset"].get<string>() << " ";
|
||||
cout << setw(6) << right << r["samples"].get<int>() << " ";
|
||||
cout << setw(5) << right << r["features"].get<int>() << " ";
|
||||
cout << setw(3) << right << r["classes"].get<int>() << " ";
|
||||
cout << setw(9) << setprecision(2) << fixed << r["nodes"].get<float>() << " ";
|
||||
cout << setw(9) << setprecision(2) << fixed << r["leaves"].get<float>() << " ";
|
||||
cout << setw(9) << setprecision(2) << fixed << r["depth"].get<float>() << " ";
|
||||
cout << setw(8) << right << setprecision(6) << fixed << r["score"].get<double>() << "±" << setw(6) << setprecision(4) << fixed << r["score_std"].get<double>() << " ";
|
||||
cout << setw(11) << right << setprecision(6) << fixed << r["time"].get<double>() << "±" << setw(6) << setprecision(4) << fixed << r["time_std"].get<double>() << " ";
|
||||
cout << setw(8) << right << setprecision(6) << fixed << r["score"].get<double>() << "±" << setw(6) << setprecision(4) << fixed << r["score_std"].get<double>();
|
||||
const string status = compareResult(r["dataset"].get<string>(), r["score"].get<double>());
|
||||
cout << status;
|
||||
cout << setw(12) << right << setprecision(6) << fixed << r["time"].get<double>() << "±" << setw(6) << setprecision(4) << fixed << r["time_std"].get<double>() << " ";
|
||||
try {
|
||||
cout << r["hyperparameters"].get<string>();
|
||||
}
|
||||
@ -81,9 +83,21 @@ namespace platform {
|
||||
footer(totalScore);
|
||||
}
|
||||
}
|
||||
void ReportConsole::showSummary()
|
||||
{
|
||||
for (const auto& item : summary) {
|
||||
stringstream oss;
|
||||
oss << setw(3) << left << item.first;
|
||||
oss << setw(3) << right << item.second << " ";
|
||||
oss << left << meaning.at(item.first);
|
||||
cout << headerLine(oss.str(), 2);
|
||||
}
|
||||
}
|
||||
|
||||
void ReportConsole::footer(double totalScore)
|
||||
{
|
||||
cout << Colors::MAGENTA() << string(MAXL, '*') << endl;
|
||||
showSummary();
|
||||
auto score = data["score_name"].get<string>();
|
||||
if (score == BestResult::scoreName()) {
|
||||
stringstream oss;
|
||||
|
@ -7,17 +7,18 @@
|
||||
|
||||
namespace platform {
|
||||
using namespace std;
|
||||
const int MAXL = 132;
|
||||
const int MAXL = 133;
|
||||
class ReportConsole : public ReportBase {
|
||||
public:
|
||||
explicit ReportConsole(json data_, int index = -1) : ReportBase(data_), selectedIndex(index) {};
|
||||
explicit ReportConsole(json data_, bool compare = false, int index = -1) : ReportBase(data_, compare), selectedIndex(index) {};
|
||||
virtual ~ReportConsole() = default;
|
||||
private:
|
||||
int selectedIndex;
|
||||
string headerLine(const string& text);
|
||||
string headerLine(const string& text, int utf);
|
||||
void header() override;
|
||||
void body() override;
|
||||
void footer(double totalScore);
|
||||
void showSummary();
|
||||
};
|
||||
};
|
||||
#endif
|
@ -13,17 +13,195 @@ namespace platform {
|
||||
string do_grouping() const { return "\03"; }
|
||||
};
|
||||
|
||||
ReportExcel::ReportExcel(json data_, bool compare, lxw_workbook* workbook) : ReportBase(data_, compare), row(0), workbook(workbook)
|
||||
{
|
||||
normalSize = 14; //font size for report body
|
||||
colorTitle = 0xB1A0C7;
|
||||
colorOdd = 0xDCE6F1;
|
||||
colorEven = 0xFDE9D9;
|
||||
createFile();
|
||||
}
|
||||
|
||||
lxw_workbook* ReportExcel::getWorkbook()
|
||||
{
|
||||
return workbook;
|
||||
}
|
||||
|
||||
lxw_format* ReportExcel::efectiveStyle(const string& style)
|
||||
{
|
||||
lxw_format* efectiveStyle;
|
||||
if (style == "") {
|
||||
efectiveStyle = NULL;
|
||||
} else {
|
||||
string suffix = row % 2 ? "_odd" : "_even";
|
||||
efectiveStyle = styles.at(style + suffix);
|
||||
}
|
||||
return efectiveStyle;
|
||||
}
|
||||
|
||||
void ReportExcel::writeString(int row, int col, const string& text, const string& style)
|
||||
{
|
||||
worksheet_write_string(worksheet, row, col, text.c_str(), efectiveStyle(style));
|
||||
}
|
||||
void ReportExcel::writeInt(int row, int col, const int number, const string& style)
|
||||
{
|
||||
worksheet_write_number(worksheet, row, col, number, efectiveStyle(style));
|
||||
}
|
||||
void ReportExcel::writeDouble(int row, int col, const double number, const string& style)
|
||||
{
|
||||
worksheet_write_number(worksheet, row, col, number, efectiveStyle(style));
|
||||
}
|
||||
|
||||
void ReportExcel::formatColumns()
|
||||
{
|
||||
worksheet_freeze_panes(worksheet, 6, 1);
|
||||
vector<int> columns_sizes = { 22, 10, 9, 7, 12, 12, 12, 12, 12, 3, 15, 12, 23 };
|
||||
for (int i = 0; i < columns_sizes.size(); ++i) {
|
||||
worksheet_set_column(worksheet, i, i, columns_sizes.at(i), NULL);
|
||||
}
|
||||
}
|
||||
|
||||
void ReportExcel::addColor(lxw_format* style, bool odd)
|
||||
{
|
||||
uint32_t efectiveColor = odd ? colorEven : colorOdd;
|
||||
format_set_bg_color(style, lxw_color_t(efectiveColor));
|
||||
}
|
||||
void ReportExcel::createStyle(const string& name, lxw_format* style, bool odd)
|
||||
{
|
||||
addColor(style, odd);
|
||||
if (name == "textCentered") {
|
||||
format_set_align(style, LXW_ALIGN_CENTER);
|
||||
format_set_font_size(style, normalSize);
|
||||
format_set_border(style, LXW_BORDER_THIN);
|
||||
} else if (name == "text") {
|
||||
format_set_font_size(style, normalSize);
|
||||
format_set_border(style, LXW_BORDER_THIN);
|
||||
} else if (name == "bodyHeader") {
|
||||
format_set_bold(style);
|
||||
format_set_font_size(style, normalSize);
|
||||
format_set_align(style, LXW_ALIGN_CENTER);
|
||||
format_set_align(style, LXW_ALIGN_VERTICAL_CENTER);
|
||||
format_set_border(style, LXW_BORDER_THIN);
|
||||
format_set_bg_color(style, lxw_color_t(colorTitle));
|
||||
} else if (name == "result") {
|
||||
format_set_font_size(style, normalSize);
|
||||
format_set_border(style, LXW_BORDER_THIN);
|
||||
format_set_num_format(style, "0.0000000");
|
||||
} else if (name == "time") {
|
||||
format_set_font_size(style, normalSize);
|
||||
format_set_border(style, LXW_BORDER_THIN);
|
||||
format_set_num_format(style, "#,##0.000000");
|
||||
} else if (name == "ints") {
|
||||
format_set_font_size(style, normalSize);
|
||||
format_set_num_format(style, "###,##0");
|
||||
format_set_border(style, LXW_BORDER_THIN);
|
||||
} else if (name == "floats") {
|
||||
format_set_border(style, LXW_BORDER_THIN);
|
||||
format_set_font_size(style, normalSize);
|
||||
format_set_num_format(style, "#,##0.00");
|
||||
}
|
||||
}
|
||||
|
||||
void ReportExcel::createFormats()
|
||||
{
|
||||
auto styleNames = { "text", "textCentered", "bodyHeader", "result", "time", "ints", "floats" };
|
||||
lxw_format* style;
|
||||
for (string name : styleNames) {
|
||||
lxw_format* style = workbook_add_format(workbook);
|
||||
style = workbook_add_format(workbook);
|
||||
createStyle(name, style, true);
|
||||
styles[name + "_odd"] = style;
|
||||
style = workbook_add_format(workbook);
|
||||
createStyle(name, style, false);
|
||||
styles[name + "_even"] = style;
|
||||
}
|
||||
|
||||
// Header 1st line
|
||||
lxw_format* headerFirst = workbook_add_format(workbook);
|
||||
format_set_bold(headerFirst);
|
||||
format_set_font_size(headerFirst, 18);
|
||||
format_set_align(headerFirst, LXW_ALIGN_CENTER);
|
||||
format_set_align(headerFirst, LXW_ALIGN_VERTICAL_CENTER);
|
||||
format_set_border(headerFirst, LXW_BORDER_THIN);
|
||||
format_set_bg_color(headerFirst, lxw_color_t(colorTitle));
|
||||
|
||||
// Header rest
|
||||
lxw_format* headerRest = workbook_add_format(workbook);
|
||||
format_set_bold(headerRest);
|
||||
format_set_align(headerRest, LXW_ALIGN_CENTER);
|
||||
format_set_font_size(headerRest, 16);
|
||||
format_set_align(headerRest, LXW_ALIGN_VERTICAL_CENTER);
|
||||
format_set_border(headerRest, LXW_BORDER_THIN);
|
||||
format_set_bg_color(headerRest, lxw_color_t(colorOdd));
|
||||
|
||||
// Header small
|
||||
lxw_format* headerSmall = workbook_add_format(workbook);
|
||||
format_set_bold(headerSmall);
|
||||
format_set_align(headerSmall, LXW_ALIGN_LEFT);
|
||||
format_set_font_size(headerSmall, 12);
|
||||
format_set_border(headerSmall, LXW_BORDER_THIN);
|
||||
format_set_align(headerSmall, LXW_ALIGN_VERTICAL_CENTER);
|
||||
format_set_bg_color(headerSmall, lxw_color_t(colorOdd));
|
||||
|
||||
// Summary style
|
||||
lxw_format* summaryStyle = workbook_add_format(workbook);
|
||||
format_set_bold(summaryStyle);
|
||||
format_set_font_size(summaryStyle, 16);
|
||||
format_set_border(summaryStyle, LXW_BORDER_THIN);
|
||||
format_set_align(summaryStyle, LXW_ALIGN_VERTICAL_CENTER);
|
||||
|
||||
styles["headerFirst"] = headerFirst;
|
||||
styles["headerRest"] = headerRest;
|
||||
styles["headerSmall"] = headerSmall;
|
||||
styles["summaryStyle"] = summaryStyle;
|
||||
}
|
||||
|
||||
void ReportExcel::setProperties()
|
||||
{
|
||||
char line[data["title"].get<string>().size() + 1];
|
||||
strcpy(line, data["title"].get<string>().c_str());
|
||||
lxw_doc_properties properties = {
|
||||
.title = line,
|
||||
.subject = "Machine learning results",
|
||||
.author = "Ricardo Montañana Gómez",
|
||||
.manager = "Dr. J. A. Gámez, Dr. J. M. Puerta",
|
||||
.company = "UCLM",
|
||||
.comments = "Created with libxlsxwriter and c++",
|
||||
};
|
||||
workbook_set_properties(workbook, &properties);
|
||||
}
|
||||
|
||||
void ReportExcel::createFile()
|
||||
{
|
||||
doc.create(Paths::excel() + "some_results.xlsx");
|
||||
wks = doc.workbook().worksheet("Sheet1");
|
||||
wks.setName(data["model"].get<string>());
|
||||
if (workbook == NULL) {
|
||||
workbook = workbook_new((Paths::excel() + fileName).c_str());
|
||||
}
|
||||
const string name = data["model"].get<string>();
|
||||
string suffix = "";
|
||||
string efectiveName;
|
||||
int num = 1;
|
||||
// Create a sheet with the name of the model
|
||||
while (true) {
|
||||
efectiveName = name + suffix;
|
||||
if (workbook_get_worksheet_by_name(workbook, efectiveName.c_str())) {
|
||||
suffix = to_string(++num);
|
||||
} else {
|
||||
worksheet = workbook_add_worksheet(workbook, efectiveName.c_str());
|
||||
break;
|
||||
}
|
||||
if (num > 100) {
|
||||
throw invalid_argument("Couldn't create sheet " + efectiveName);
|
||||
}
|
||||
}
|
||||
cout << "Adding sheet " << efectiveName << " to " << Paths::excel() + fileName << endl;
|
||||
setProperties();
|
||||
createFormats();
|
||||
formatColumns();
|
||||
}
|
||||
|
||||
void ReportExcel::closeFile()
|
||||
{
|
||||
doc.save();
|
||||
doc.close();
|
||||
workbook_close(workbook);
|
||||
}
|
||||
|
||||
void ReportExcel::header()
|
||||
@ -32,45 +210,62 @@ namespace platform {
|
||||
locale::global(mylocale);
|
||||
cout.imbue(mylocale);
|
||||
stringstream oss;
|
||||
wks.cell("A1").value().set(
|
||||
"Report " + data["model"].get<string>() + " ver. " + data["version"].get<string>() + " with " +
|
||||
to_string(data["folds"].get<int>()) + " Folds cross validation and " + to_string(data["seeds"].size()) +
|
||||
" random seeds. " + data["date"].get<string>() + " " + data["time"].get<string>());
|
||||
wks.cell("A2").value() = data["title"].get<string>();
|
||||
wks.cell("A3").value() = "Random seeds: " + fromVector("seeds") + " Stratified: " +
|
||||
(data["stratified"].get<bool>() ? "True" : "False");
|
||||
oss << "Execution took " << setprecision(2) << fixed << data["duration"].get<float>() << " seconds, "
|
||||
<< data["duration"].get<float>() / 3600 << " hours, on " << data["platform"].get<string>();
|
||||
wks.cell("A4").value() = oss.str();
|
||||
wks.cell("A5").value() = "Score is " + data["score_name"].get<string>();
|
||||
string message = data["model"].get<string>() + " ver. " + data["version"].get<string>() + " " +
|
||||
data["language"].get<string>() + " ver. " + data["language_version"].get<string>() +
|
||||
" with " + to_string(data["folds"].get<int>()) + " Folds cross validation and " + to_string(data["seeds"].size()) +
|
||||
" random seeds. " + data["date"].get<string>() + " " + data["time"].get<string>();
|
||||
worksheet_merge_range(worksheet, 0, 0, 0, 12, message.c_str(), styles["headerFirst"]);
|
||||
worksheet_merge_range(worksheet, 1, 0, 1, 12, data["title"].get<string>().c_str(), styles["headerRest"]);
|
||||
worksheet_merge_range(worksheet, 2, 0, 3, 0, ("Score is " + data["score_name"].get<string>()).c_str(), styles["headerRest"]);
|
||||
worksheet_merge_range(worksheet, 2, 1, 3, 3, "Execution time", styles["headerRest"]);
|
||||
oss << setprecision(2) << fixed << data["duration"].get<float>() << " s";
|
||||
worksheet_merge_range(worksheet, 2, 4, 2, 5, oss.str().c_str(), styles["headerRest"]);
|
||||
oss.str("");
|
||||
oss.clear();
|
||||
oss << setprecision(2) << fixed << data["duration"].get<float>() / 3600 << " h";
|
||||
worksheet_merge_range(worksheet, 3, 4, 3, 5, oss.str().c_str(), styles["headerRest"]);
|
||||
worksheet_merge_range(worksheet, 2, 6, 3, 7, "Platform", styles["headerRest"]);
|
||||
worksheet_merge_range(worksheet, 2, 8, 3, 9, data["platform"].get<string>().c_str(), styles["headerRest"]);
|
||||
worksheet_merge_range(worksheet, 2, 10, 2, 12, ("Random seeds: " + fromVector("seeds")).c_str(), styles["headerSmall"]);
|
||||
oss.str("");
|
||||
oss.clear();
|
||||
oss << "Stratified: " << (data["stratified"].get<bool>() ? "True" : "False");
|
||||
worksheet_merge_range(worksheet, 3, 10, 3, 11, oss.str().c_str(), styles["headerSmall"]);
|
||||
oss.str("");
|
||||
oss.clear();
|
||||
oss << "Discretized: " << (data["discretized"].get<bool>() ? "True" : "False");
|
||||
worksheet_write_string(worksheet, 3, 12, oss.str().c_str(), styles["headerSmall"]);
|
||||
}
|
||||
|
||||
void ReportExcel::body()
|
||||
{
|
||||
auto head = vector<string>(
|
||||
{ "Dataset", "Samples", "Features", "Classes", "Nodes", "Edges", "States", "Score", "Score Std.", "Time",
|
||||
{ "Dataset", "Samples", "Features", "Classes", "Nodes", "Edges", "States", "Score", "Score Std.", "St.", "Time",
|
||||
"Time Std.", "Hyperparameters" });
|
||||
int col = 1;
|
||||
int col = 0;
|
||||
for (const auto& item : head) {
|
||||
wks.cell(8, col++).value() = item;
|
||||
writeString(5, col++, item, "bodyHeader");
|
||||
}
|
||||
int row = 9;
|
||||
col = 1;
|
||||
row = 6;
|
||||
col = 0;
|
||||
int hypSize = 22;
|
||||
json lastResult;
|
||||
double totalScore = 0.0;
|
||||
string hyperparameters;
|
||||
for (const auto& r : data["results"]) {
|
||||
wks.cell(row, col).value() = r["dataset"].get<string>();
|
||||
wks.cell(row, col + 1).value() = r["samples"].get<int>();
|
||||
wks.cell(row, col + 2).value() = r["features"].get<int>();
|
||||
wks.cell(row, col + 3).value() = r["classes"].get<int>();
|
||||
wks.cell(row, col + 4).value() = r["nodes"].get<float>();
|
||||
wks.cell(row, col + 5).value() = r["leaves"].get<float>();
|
||||
wks.cell(row, col + 6).value() = r["depth"].get<float>();
|
||||
wks.cell(row, col + 7).value() = r["score"].get<double>();
|
||||
wks.cell(row, col + 8).value() = r["score_std"].get<double>();
|
||||
wks.cell(row, col + 9).value() = r["time"].get<double>();
|
||||
wks.cell(row, col + 10).value() = r["time_std"].get<double>();
|
||||
writeString(row, col, r["dataset"].get<string>(), "text");
|
||||
writeInt(row, col + 1, r["samples"].get<int>(), "ints");
|
||||
writeInt(row, col + 2, r["features"].get<int>(), "ints");
|
||||
writeInt(row, col + 3, r["classes"].get<int>(), "ints");
|
||||
writeDouble(row, col + 4, r["nodes"].get<float>(), "floats");
|
||||
writeDouble(row, col + 5, r["leaves"].get<float>(), "floats");
|
||||
writeDouble(row, col + 6, r["depth"].get<double>(), "floats");
|
||||
writeDouble(row, col + 7, r["score"].get<double>(), "result");
|
||||
writeDouble(row, col + 8, r["score_std"].get<double>(), "result");
|
||||
const string status = compareResult(r["dataset"].get<string>(), r["score"].get<double>());
|
||||
writeString(row, col + 9, status, "textCentered");
|
||||
writeDouble(row, col + 10, r["time"].get<double>(), "time");
|
||||
writeDouble(row, col + 11, r["time_std"].get<double>(), "time");
|
||||
try {
|
||||
hyperparameters = r["hyperparameters"].get<string>();
|
||||
}
|
||||
@ -79,31 +274,57 @@ namespace platform {
|
||||
oss << r["hyperparameters"];
|
||||
hyperparameters = oss.str();
|
||||
}
|
||||
wks.cell(row, col + 11).value() = hyperparameters;
|
||||
if (hyperparameters.size() > hypSize) {
|
||||
hypSize = hyperparameters.size();
|
||||
}
|
||||
writeString(row, col + 12, hyperparameters, "text");
|
||||
lastResult = r;
|
||||
totalScore += r["score"].get<double>();
|
||||
row++;
|
||||
|
||||
}
|
||||
// Set the right column width of hyperparameters with the maximum length
|
||||
worksheet_set_column(worksheet, 12, 12, hypSize + 5, NULL);
|
||||
// Show totals if only one dataset is present in the result
|
||||
if (data["results"].size() == 1) {
|
||||
for (const string& group : { "scores_train", "scores_test", "times_train", "times_test" }) {
|
||||
row++;
|
||||
col = 1;
|
||||
wks.cell(row, col).value() = group;
|
||||
writeString(row, col, group, "text");
|
||||
for (double item : lastResult[group]) {
|
||||
wks.cell(row, ++col).value() = item;
|
||||
string style = group.find("scores") != string::npos ? "result" : "time";
|
||||
writeDouble(row, ++col, item, style);
|
||||
}
|
||||
}
|
||||
// Set with of columns to show those totals completely
|
||||
worksheet_set_column(worksheet, 1, 1, 12, NULL);
|
||||
for (int i = 2; i < 7; ++i) {
|
||||
// doesn't work with from col to col, so...
|
||||
worksheet_set_column(worksheet, i, i, 15, NULL);
|
||||
}
|
||||
} else {
|
||||
footer(totalScore, row);
|
||||
}
|
||||
}
|
||||
|
||||
void ReportExcel::showSummary()
|
||||
{
|
||||
for (const auto& item : summary) {
|
||||
worksheet_write_string(worksheet, row + 2, 1, item.first.c_str(), styles["summaryStyle"]);
|
||||
worksheet_write_number(worksheet, row + 2, 2, item.second, styles["summaryStyle"]);
|
||||
worksheet_merge_range(worksheet, row + 2, 3, row + 2, 5, meaning.at(item.first).c_str(), styles["summaryStyle"]);
|
||||
row += 1;
|
||||
}
|
||||
}
|
||||
|
||||
void ReportExcel::footer(double totalScore, int row)
|
||||
{
|
||||
showSummary();
|
||||
row += 4 + summary.size();
|
||||
auto score = data["score_name"].get<string>();
|
||||
if (score == BestResult::scoreName()) {
|
||||
wks.cell(row + 2, 1).value() = score + " compared to " + BestResult::title() + " .: ";
|
||||
wks.cell(row + 2, 5).value() = totalScore / BestResult::score();
|
||||
worksheet_merge_range(worksheet, row, 1, row, 5, (score + " compared to " + BestResult::title() + " .:").c_str(), efectiveStyle("text"));
|
||||
writeDouble(row, 6, totalScore / BestResult::score(), "result");
|
||||
}
|
||||
}
|
||||
}
|
@ -1,25 +1,42 @@
|
||||
#ifndef REPORTEXCEL_H
|
||||
#define REPORTEXCEL_H
|
||||
#include <OpenXLSX.hpp>
|
||||
#include<map>
|
||||
#include "xlsxwriter.h"
|
||||
#include "ReportBase.h"
|
||||
#include "Paths.h"
|
||||
#include "Colors.h"
|
||||
namespace platform {
|
||||
using namespace std;
|
||||
using namespace OpenXLSX;
|
||||
const int MAXLL = 128;
|
||||
class ReportExcel : public ReportBase{
|
||||
|
||||
class ReportExcel : public ReportBase {
|
||||
public:
|
||||
explicit ReportExcel(json data_) : ReportBase(data_) {createFile();};
|
||||
virtual ~ReportExcel() {closeFile();};
|
||||
explicit ReportExcel(json data_, bool compare, lxw_workbook* workbook);
|
||||
lxw_workbook* getWorkbook();
|
||||
private:
|
||||
void writeString(int row, int col, const string& text, const string& style = "");
|
||||
void writeInt(int row, int col, const int number, const string& style = "");
|
||||
void writeDouble(int row, int col, const double number, const string& style = "");
|
||||
void formatColumns();
|
||||
void createFormats();
|
||||
void setProperties();
|
||||
void createFile();
|
||||
void closeFile();
|
||||
XLDocument doc;
|
||||
XLWorksheet wks;
|
||||
void showSummary();
|
||||
lxw_workbook* workbook;
|
||||
lxw_worksheet* worksheet;
|
||||
map<string, lxw_format*> styles;
|
||||
int row;
|
||||
int normalSize; //font size for report body
|
||||
uint32_t colorTitle;
|
||||
uint32_t colorOdd;
|
||||
uint32_t colorEven;
|
||||
const string fileName = "some_results.xlsx";
|
||||
void header() override;
|
||||
void body() override;
|
||||
void footer(double totalScore, int row);
|
||||
void createStyle(const string& name, lxw_format* style, bool odd);
|
||||
void addColor(lxw_format* style, bool odd);
|
||||
lxw_format* efectiveStyle(const string& name);
|
||||
};
|
||||
};
|
||||
#endif // !REPORTEXCEL_H
|
@ -104,15 +104,17 @@ namespace platform {
|
||||
cout << "Invalid index" << endl;
|
||||
return -1;
|
||||
}
|
||||
void Results::report(const int index, const bool excelReport) const
|
||||
void Results::report(const int index, const bool excelReport)
|
||||
{
|
||||
cout << Colors::YELLOW() << "Reporting " << files.at(index).getFilename() << endl;
|
||||
auto data = files.at(index).load();
|
||||
if (excelReport) {
|
||||
ReportExcel reporter(data);
|
||||
ReportExcel reporter(data, compare, workbook);
|
||||
reporter.show();
|
||||
openExcel = true;
|
||||
workbook = reporter.getWorkbook();
|
||||
} else {
|
||||
ReportConsole reporter(data);
|
||||
ReportConsole reporter(data, compare);
|
||||
reporter.show();
|
||||
}
|
||||
}
|
||||
@ -124,7 +126,7 @@ namespace platform {
|
||||
return;
|
||||
}
|
||||
cout << Colors::YELLOW() << "Showing " << files.at(index).getFilename() << endl;
|
||||
ReportConsole reporter(data, idx);
|
||||
ReportConsole reporter(data, compare, idx);
|
||||
reporter.show();
|
||||
}
|
||||
void Results::menu()
|
||||
@ -132,9 +134,21 @@ namespace platform {
|
||||
char option;
|
||||
int index;
|
||||
bool finished = false;
|
||||
string color, context;
|
||||
string filename, line, options = "qldhsre";
|
||||
while (!finished) {
|
||||
cout << Colors::RESET() << "Choose option (quit='q', list='l', delete='d', hide='h', sort='s', report='r', excel='e'): ";
|
||||
if (indexList) {
|
||||
color = Colors::GREEN();
|
||||
context = " (quit='q', list='l', delete='d', hide='h', sort='s', report='r', excel='e'): ";
|
||||
options = "qldhsre";
|
||||
} else {
|
||||
color = Colors::MAGENTA();
|
||||
context = " (quit='q', list='l'): ";
|
||||
options = "ql";
|
||||
}
|
||||
cout << Colors::RESET() << color;
|
||||
|
||||
cout << "Choose option " << context;
|
||||
getline(cin, line);
|
||||
if (line.size() == 0)
|
||||
continue;
|
||||
@ -148,6 +162,7 @@ namespace platform {
|
||||
if (all_of(line.begin(), line.end(), ::isdigit)) {
|
||||
int idx = stoi(line);
|
||||
if (indexList) {
|
||||
// The value is about the files list
|
||||
index = idx;
|
||||
if (index >= 0 && index < files.size()) {
|
||||
report(index, false);
|
||||
@ -155,6 +170,7 @@ namespace platform {
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
// The value is about the result showed on screen
|
||||
showIndex(index, idx);
|
||||
continue;
|
||||
}
|
||||
@ -281,6 +297,9 @@ namespace platform {
|
||||
sortDate();
|
||||
show();
|
||||
menu();
|
||||
if (openExcel) {
|
||||
workbook_close(workbook);
|
||||
}
|
||||
cout << "Done!" << endl;
|
||||
}
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
#ifndef RESULTS_H
|
||||
#define RESULTS_H
|
||||
#include "xlsxwriter.h"
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
@ -34,7 +35,11 @@ namespace platform {
|
||||
};
|
||||
class Results {
|
||||
public:
|
||||
Results(const string& path, const int max, const string& model, const string& score, bool complete, bool partial) : path(path), max(max), model(model), scoreName(score), complete(complete), partial(partial) { load(); };
|
||||
Results(const string& path, const int max, const string& model, const string& score, bool complete, bool partial, bool compare) :
|
||||
path(path), max(max), model(model), scoreName(score), complete(complete), partial(partial), compare(compare)
|
||||
{
|
||||
load();
|
||||
};
|
||||
void manage();
|
||||
private:
|
||||
string path;
|
||||
@ -44,10 +49,13 @@ namespace platform {
|
||||
bool complete;
|
||||
bool partial;
|
||||
bool indexList = true;
|
||||
bool openExcel = false;
|
||||
bool compare;
|
||||
lxw_workbook* workbook = NULL;
|
||||
vector<Result> files;
|
||||
void load(); // Loads the list of results
|
||||
void show() const;
|
||||
void report(const int index, const bool excelReport) const;
|
||||
void report(const int index, const bool excelReport);
|
||||
void showIndex(const int index, const int idx) const;
|
||||
int getIndex(const string& intent) const;
|
||||
void menu();
|
||||
|
@ -87,7 +87,7 @@ int main(int argc, char** argv)
|
||||
auto stratified = program.get<bool>("stratified");
|
||||
auto n_folds = program.get<int>("folds");
|
||||
auto seeds = program.get<vector<int>>("seeds");
|
||||
auto hyperparameters =program.get<string>("hyperparameters");
|
||||
auto hyperparameters = program.get<string>("hyperparameters");
|
||||
vector<string> filesToTest;
|
||||
auto datasets = platform::Datasets(path, true, platform::ARFF);
|
||||
auto title = program.get<string>("title");
|
||||
@ -102,7 +102,7 @@ int main(int argc, char** argv)
|
||||
}
|
||||
filesToTest.push_back(file_name);
|
||||
} else {
|
||||
filesToTest = platform::Datasets(path, true, platform::ARFF).getNames();
|
||||
filesToTest = datasets.getNames();
|
||||
saveResults = true;
|
||||
}
|
||||
/*
|
||||
|
@ -14,6 +14,7 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
|
||||
program.add_argument("-s", "--score").default_value("any").help("Filter results of the score name supplied");
|
||||
program.add_argument("--complete").help("Show only results with all datasets").default_value(false).implicit_value(true);
|
||||
program.add_argument("--partial").help("Show only partial results").default_value(false).implicit_value(true);
|
||||
program.add_argument("--compare").help("Compare with best results").default_value(false).implicit_value(true);
|
||||
try {
|
||||
program.parse_args(argc, argv);
|
||||
auto number = program.get<int>("number");
|
||||
@ -24,6 +25,7 @@ argparse::ArgumentParser manageArguments(int argc, char** argv)
|
||||
auto score = program.get<string>("score");
|
||||
auto complete = program.get<bool>("complete");
|
||||
auto partial = program.get<bool>("partial");
|
||||
auto compare = program.get<bool>("compare");
|
||||
}
|
||||
catch (const exception& err) {
|
||||
cerr << err.what() << endl;
|
||||
@ -41,9 +43,10 @@ int main(int argc, char** argv)
|
||||
auto score = program.get<string>("score");
|
||||
auto complete = program.get<bool>("complete");
|
||||
auto partial = program.get<bool>("partial");
|
||||
auto compare = program.get<bool>("compare");
|
||||
if (complete)
|
||||
partial = false;
|
||||
auto results = platform::Results(platform::Paths::results(), number, model, score, complete, partial);
|
||||
auto results = platform::Results(platform::Paths::results(), number, model, score, complete, partial, compare);
|
||||
results.manage();
|
||||
return 0;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user