Complete comparison with ZeroR
This commit is contained in:
parent
1bdfbd1620
commit
f69f415b92
@ -5,12 +5,12 @@ include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp)
|
|||||||
include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include)
|
include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include)
|
||||||
include_directories(${BayesNet_SOURCE_DIR}/lib/json/include)
|
include_directories(${BayesNet_SOURCE_DIR}/lib/json/include)
|
||||||
add_executable(main main.cc Folding.cc platformUtils.cc Experiment.cc Datasets.cc Models.cc ReportConsole.cc ReportBase.cc)
|
add_executable(main main.cc Folding.cc platformUtils.cc Experiment.cc Datasets.cc Models.cc ReportConsole.cc ReportBase.cc)
|
||||||
add_executable(manage manage.cc Results.cc ReportConsole.cc ReportExcel.cc ReportBase.cc)
|
add_executable(manage manage.cc Results.cc ReportConsole.cc ReportExcel.cc ReportBase.cc Datasets.cc platformUtils.cc)
|
||||||
add_executable(list list.cc platformUtils Datasets.cc)
|
add_executable(list list.cc platformUtils Datasets.cc)
|
||||||
target_link_libraries(main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}")
|
target_link_libraries(main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}")
|
||||||
if (${CMAKE_HOST_SYSTEM_NAME} MATCHES "Linux")
|
if (${CMAKE_HOST_SYSTEM_NAME} MATCHES "Linux")
|
||||||
target_link_libraries(manage "${TORCH_LIBRARIES}" libxlsxwriter.so stdc++fs)
|
target_link_libraries(manage "${TORCH_LIBRARIES}" libxlsxwriter.so ArffFiles mdlp stdc++fs)
|
||||||
target_link_libraries(manage "${TORCH_LIBRARIES}" libxlsxwriter.so)
|
|
||||||
else()
|
else()
|
||||||
|
target_link_libraries(manage "${TORCH_LIBRARIES}" libxlsxwriter.so ArffFiles mdlp)
|
||||||
endif()
|
endif()
|
||||||
target_link_libraries(list ArffFiles mdlp "${TORCH_LIBRARIES}")
|
target_link_libraries(list ArffFiles mdlp "${TORCH_LIBRARIES}")
|
@ -1,5 +1,6 @@
|
|||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <locale>
|
#include <locale>
|
||||||
|
#include "Datasets.h"
|
||||||
#include "ReportExcel.h"
|
#include "ReportExcel.h"
|
||||||
#include "BestResult.h"
|
#include "BestResult.h"
|
||||||
|
|
||||||
@ -12,6 +13,15 @@ namespace platform {
|
|||||||
|
|
||||||
string do_grouping() const { return "\03"; }
|
string do_grouping() const { return "\03"; }
|
||||||
};
|
};
|
||||||
|
ReportExcel::ReportExcel(json data_) : ReportBase(data_), row(0)
|
||||||
|
{
|
||||||
|
normalSize = 14; //font size for report body
|
||||||
|
colorTitle = 0xB1A0C7;
|
||||||
|
colorOdd = 0xDCE6F1;
|
||||||
|
colorEven = 0xFDE9D9;
|
||||||
|
margin = .1; // margin to add to ZeroR comparison
|
||||||
|
createFile();
|
||||||
|
}
|
||||||
|
|
||||||
lxw_format* ReportExcel::efectiveStyle(const string& style)
|
lxw_format* ReportExcel::efectiveStyle(const string& style)
|
||||||
{
|
{
|
||||||
@ -41,7 +51,7 @@ namespace platform {
|
|||||||
void ReportExcel::formatColumns()
|
void ReportExcel::formatColumns()
|
||||||
{
|
{
|
||||||
worksheet_freeze_panes(worksheet, 6, 1);
|
worksheet_freeze_panes(worksheet, 6, 1);
|
||||||
vector<int> columns_sizes = { 22, 10, 9, 7, 12, 12, 12, 12, 12, 3, 14, 12, 50 };
|
vector<int> columns_sizes = { 22, 10, 9, 7, 12, 12, 12, 12, 12, 3, 15, 12, 23 };
|
||||||
for (int i = 0; i < columns_sizes.size(); ++i) {
|
for (int i = 0; i < columns_sizes.size(); ++i) {
|
||||||
worksheet_set_column(worksheet, i, i, columns_sizes.at(i), NULL);
|
worksheet_set_column(worksheet, i, i, columns_sizes.at(i), NULL);
|
||||||
}
|
}
|
||||||
@ -65,9 +75,9 @@ namespace platform {
|
|||||||
} else if (name == "bodyHeader") {
|
} else if (name == "bodyHeader") {
|
||||||
format_set_bold(style);
|
format_set_bold(style);
|
||||||
format_set_font_size(style, normalSize);
|
format_set_font_size(style, normalSize);
|
||||||
format_set_align(style, LXW_ALIGN_VERTICAL_CENTER);
|
|
||||||
format_set_align(style, LXW_ALIGN_CENTER);
|
format_set_align(style, LXW_ALIGN_CENTER);
|
||||||
format_set_align(style, LXW_ALIGN_VERTICAL_CENTER);
|
format_set_align(style, LXW_ALIGN_VERTICAL_CENTER);
|
||||||
|
format_set_border(style, LXW_BORDER_THIN);
|
||||||
format_set_bg_color(style, lxw_color_t(colorTitle));
|
format_set_bg_color(style, lxw_color_t(colorTitle));
|
||||||
} else if (name == "result") {
|
} else if (name == "result") {
|
||||||
format_set_font_size(style, normalSize);
|
format_set_font_size(style, normalSize);
|
||||||
@ -129,9 +139,17 @@ namespace platform {
|
|||||||
format_set_align(headerSmall, LXW_ALIGN_VERTICAL_CENTER);
|
format_set_align(headerSmall, LXW_ALIGN_VERTICAL_CENTER);
|
||||||
format_set_bg_color(headerSmall, lxw_color_t(colorOdd));
|
format_set_bg_color(headerSmall, lxw_color_t(colorOdd));
|
||||||
|
|
||||||
|
// Summary style
|
||||||
|
lxw_format* summaryStyle = workbook_add_format(workbook);
|
||||||
|
format_set_bold(summaryStyle);
|
||||||
|
format_set_font_size(summaryStyle, 16);
|
||||||
|
format_set_border(summaryStyle, LXW_BORDER_THIN);
|
||||||
|
format_set_align(summaryStyle, LXW_ALIGN_VERTICAL_CENTER);
|
||||||
|
|
||||||
styles["headerFirst"] = headerFirst;
|
styles["headerFirst"] = headerFirst;
|
||||||
styles["headerRest"] = headerRest;
|
styles["headerRest"] = headerRest;
|
||||||
styles["headerSmall"] = headerSmall;
|
styles["headerSmall"] = headerSmall;
|
||||||
|
styles["summaryStyle"] = summaryStyle;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ReportExcel::setProperties()
|
void ReportExcel::setProperties()
|
||||||
@ -173,7 +191,7 @@ namespace platform {
|
|||||||
locale::global(mylocale);
|
locale::global(mylocale);
|
||||||
cout.imbue(mylocale);
|
cout.imbue(mylocale);
|
||||||
stringstream oss;
|
stringstream oss;
|
||||||
string message = data["model"].get<string>() + " ver. " + data["version"].get<string>() +
|
string message = data["model"].get<string>() + " ver. " + data["version"].get<string>() + " " +
|
||||||
data["language"].get<string>() + " ver. " + data["language_version"].get<string>() +
|
data["language"].get<string>() + " ver. " + data["language_version"].get<string>() +
|
||||||
" with " + to_string(data["folds"].get<int>()) + " Folds cross validation and " + to_string(data["seeds"].size()) +
|
" with " + to_string(data["folds"].get<int>()) + " Folds cross validation and " + to_string(data["seeds"].size()) +
|
||||||
" random seeds. " + data["date"].get<string>() + " " + data["time"].get<string>();
|
" random seeds. " + data["date"].get<string>() + " " + data["time"].get<string>();
|
||||||
@ -211,6 +229,7 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
row = 6;
|
row = 6;
|
||||||
col = 0;
|
col = 0;
|
||||||
|
int hypSize = 22;
|
||||||
json lastResult;
|
json lastResult;
|
||||||
double totalScore = 0.0;
|
double totalScore = 0.0;
|
||||||
string hyperparameters;
|
string hyperparameters;
|
||||||
@ -224,7 +243,7 @@ namespace platform {
|
|||||||
writeDouble(row, col + 6, r["depth"].get<double>(), "floats");
|
writeDouble(row, col + 6, r["depth"].get<double>(), "floats");
|
||||||
writeDouble(row, col + 7, r["score"].get<double>(), "result");
|
writeDouble(row, col + 7, r["score"].get<double>(), "result");
|
||||||
writeDouble(row, col + 8, r["score_std"].get<double>(), "result");
|
writeDouble(row, col + 8, r["score_std"].get<double>(), "result");
|
||||||
const string status = "X";
|
const string status = compareResult(r["dataset"].get<string>(), r["score"].get<double>());
|
||||||
writeString(row, col + 9, status, "textCentered");
|
writeString(row, col + 9, status, "textCentered");
|
||||||
writeDouble(row, col + 10, r["time"].get<double>(), "time");
|
writeDouble(row, col + 10, r["time"].get<double>(), "time");
|
||||||
writeDouble(row, col + 11, r["time_std"].get<double>(), "time");
|
writeDouble(row, col + 11, r["time_std"].get<double>(), "time");
|
||||||
@ -236,11 +255,18 @@ namespace platform {
|
|||||||
oss << r["hyperparameters"];
|
oss << r["hyperparameters"];
|
||||||
hyperparameters = oss.str();
|
hyperparameters = oss.str();
|
||||||
}
|
}
|
||||||
|
if (hyperparameters.size() > hypSize) {
|
||||||
|
hypSize = hyperparameters.size();
|
||||||
|
}
|
||||||
writeString(row, col + 12, hyperparameters, "text");
|
writeString(row, col + 12, hyperparameters, "text");
|
||||||
lastResult = r;
|
lastResult = r;
|
||||||
totalScore += r["score"].get<double>();
|
totalScore += r["score"].get<double>();
|
||||||
row++;
|
row++;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
// Set the right column width of hyperparameters with the maximum length
|
||||||
|
worksheet_set_column(worksheet, 12, 12, hypSize + 1, NULL);
|
||||||
|
// Show totals if only one dataset is present in the result
|
||||||
if (data["results"].size() == 1) {
|
if (data["results"].size() == 1) {
|
||||||
for (const string& group : { "scores_train", "scores_test", "times_train", "times_test" }) {
|
for (const string& group : { "scores_train", "scores_test", "times_train", "times_test" }) {
|
||||||
row++;
|
row++;
|
||||||
@ -254,9 +280,52 @@ namespace platform {
|
|||||||
footer(totalScore, row);
|
footer(totalScore, row);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
string ReportExcel::compareResult(const string& dataset, double result)
|
||||||
|
{
|
||||||
|
string status = " ";
|
||||||
|
if (data["score_name"].get<string>() == "accuracy") {
|
||||||
|
auto dt = Datasets(Paths::datasets(), false);
|
||||||
|
dt.loadDataset(dataset);
|
||||||
|
auto numClasses = dt.getNClasses(dataset);
|
||||||
|
if (numClasses == 2) {
|
||||||
|
vector<int> distribution = dt.getClassesCounts(dataset);
|
||||||
|
vector<int>::iterator maxValue = max_element(distribution.begin(), distribution.end());
|
||||||
|
int maxCategory = distance(distribution.begin(), maxValue);
|
||||||
|
double mark = maxCategory * (1 + margin);
|
||||||
|
status = result < mark ? Symbols::cross : result > mark ? Symbols::upward_arrow : "=";
|
||||||
|
auto item = summary.find(status);
|
||||||
|
if (item != summary.end()) {
|
||||||
|
summary[status]++;
|
||||||
|
} else {
|
||||||
|
summary[status] = 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
void ReportExcel::showSummary()
|
||||||
|
{
|
||||||
|
stringstream oss;
|
||||||
|
oss << "Better than ZeroR + " << setprecision(1) << fixed << margin * 100 << "%";
|
||||||
|
|
||||||
|
map<string, string> meaning = {
|
||||||
|
{Symbols::equal_best, "Equal to best"},
|
||||||
|
{Symbols::better_best, "Better than best"},
|
||||||
|
{Symbols::cross, "Less than or equal to ZeroR"},
|
||||||
|
{Symbols::upward_arrow, oss.str()}
|
||||||
|
};
|
||||||
|
for (const auto& item : summary) {
|
||||||
|
worksheet_write_string(worksheet, row + 2, 1, item.first.c_str(), styles["summaryStyle"]);
|
||||||
|
worksheet_write_number(worksheet, row + 2, 2, item.second, styles["summaryStyle"]);
|
||||||
|
worksheet_merge_range(worksheet, row + 2, 3, row + 2, 5, meaning.at(item.first).c_str(), styles["summaryStyle"]);
|
||||||
|
row += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void ReportExcel::footer(double totalScore, int row)
|
void ReportExcel::footer(double totalScore, int row)
|
||||||
{
|
{
|
||||||
|
showSummary();
|
||||||
|
row += 2 + summary.size();
|
||||||
auto score = data["score_name"].get<string>();
|
auto score = data["score_name"].get<string>();
|
||||||
if (score == BestResult::scoreName()) {
|
if (score == BestResult::scoreName()) {
|
||||||
worksheet_merge_range(worksheet, row + 2, 1, row + 2, 5, (score + " compared to " + BestResult::title() + " .:").c_str(), styles["text_even"]);
|
worksheet_merge_range(worksheet, row + 2, 1, row + 2, 5, (score + " compared to " + BestResult::title() + " .:").c_str(), styles["text_even"]);
|
||||||
|
@ -8,9 +8,20 @@
|
|||||||
namespace platform {
|
namespace platform {
|
||||||
using namespace std;
|
using namespace std;
|
||||||
const int MAXLL = 128;
|
const int MAXLL = 128;
|
||||||
|
class Symbols {
|
||||||
|
public:
|
||||||
|
inline static const string check_mark{ "\u2714" };
|
||||||
|
inline static const string exclamation{ "\u2757" };
|
||||||
|
inline static const string black_star{ "\u2605" };
|
||||||
|
inline static const string cross{ "\u2717" };
|
||||||
|
inline static const string upward_arrow{ "\u27B6" };
|
||||||
|
inline static const string down_arrow{ "\u27B4" };
|
||||||
|
inline static const string equal_best{ check_mark };
|
||||||
|
inline static const string better_best{ black_star };
|
||||||
|
};
|
||||||
class ReportExcel : public ReportBase {
|
class ReportExcel : public ReportBase {
|
||||||
public:
|
public:
|
||||||
explicit ReportExcel(json data_) : ReportBase(data_) { createFile(); };
|
explicit ReportExcel(json data_);
|
||||||
virtual ~ReportExcel() { closeFile(); };
|
virtual ~ReportExcel() { closeFile(); };
|
||||||
private:
|
private:
|
||||||
void writeString(int row, int col, const string& text, const string& style = "");
|
void writeString(int row, int col, const string& text, const string& style = "");
|
||||||
@ -21,14 +32,17 @@ namespace platform {
|
|||||||
void setProperties();
|
void setProperties();
|
||||||
void createFile();
|
void createFile();
|
||||||
void closeFile();
|
void closeFile();
|
||||||
|
void showSummary();
|
||||||
lxw_workbook* workbook;
|
lxw_workbook* workbook;
|
||||||
lxw_worksheet* worksheet;
|
lxw_worksheet* worksheet;
|
||||||
map<string, lxw_format*> styles;
|
map<string, lxw_format*> styles;
|
||||||
int row = 0;
|
map<string, int> summary;
|
||||||
int normalSize = 14; //font size for report body
|
int row;
|
||||||
uint32_t colorTitle = 0xB1A0C7;
|
int normalSize; //font size for report body
|
||||||
uint32_t colorOdd = 0xDCE6F1;
|
uint32_t colorTitle;
|
||||||
uint32_t colorEven = 0xFDE9D9;
|
uint32_t colorOdd;
|
||||||
|
uint32_t colorEven;
|
||||||
|
double margin;
|
||||||
const string fileName = "some_results.xlsx";
|
const string fileName = "some_results.xlsx";
|
||||||
void header() override;
|
void header() override;
|
||||||
void body() override;
|
void body() override;
|
||||||
@ -36,6 +50,7 @@ namespace platform {
|
|||||||
void createStyle(const string& name, lxw_format* style, bool odd);
|
void createStyle(const string& name, lxw_format* style, bool odd);
|
||||||
void addColor(lxw_format* style, bool odd);
|
void addColor(lxw_format* style, bool odd);
|
||||||
lxw_format* efectiveStyle(const string& name);
|
lxw_format* efectiveStyle(const string& name);
|
||||||
|
string compareResult(const string& dataset, double result);
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
#endif // !REPORTEXCEL_H
|
#endif // !REPORTEXCEL_H
|
@ -87,7 +87,7 @@ int main(int argc, char** argv)
|
|||||||
auto stratified = program.get<bool>("stratified");
|
auto stratified = program.get<bool>("stratified");
|
||||||
auto n_folds = program.get<int>("folds");
|
auto n_folds = program.get<int>("folds");
|
||||||
auto seeds = program.get<vector<int>>("seeds");
|
auto seeds = program.get<vector<int>>("seeds");
|
||||||
auto hyperparameters =program.get<string>("hyperparameters");
|
auto hyperparameters = program.get<string>("hyperparameters");
|
||||||
vector<string> filesToTest;
|
vector<string> filesToTest;
|
||||||
auto datasets = platform::Datasets(path, true, platform::ARFF);
|
auto datasets = platform::Datasets(path, true, platform::ARFF);
|
||||||
auto title = program.get<string>("title");
|
auto title = program.get<string>("title");
|
||||||
@ -102,7 +102,7 @@ int main(int argc, char** argv)
|
|||||||
}
|
}
|
||||||
filesToTest.push_back(file_name);
|
filesToTest.push_back(file_name);
|
||||||
} else {
|
} else {
|
||||||
filesToTest = platform::Datasets(path, true, platform::ARFF).getNames();
|
filesToTest = datasets.getNames();
|
||||||
saveResults = true;
|
saveResults = true;
|
||||||
}
|
}
|
||||||
/*
|
/*
|
||||||
|
Loading…
Reference in New Issue
Block a user