Complete comparison with ZeroR

This commit is contained in:
Ricardo Montañana Gómez 2023-09-19 17:55:03 +02:00
parent 1bdfbd1620
commit f69f415b92
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
4 changed files with 99 additions and 15 deletions

View File

@ -5,12 +5,12 @@ include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp)
include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include)
include_directories(${BayesNet_SOURCE_DIR}/lib/json/include)
add_executable(main main.cc Folding.cc platformUtils.cc Experiment.cc Datasets.cc Models.cc ReportConsole.cc ReportBase.cc)
add_executable(manage manage.cc Results.cc ReportConsole.cc ReportExcel.cc ReportBase.cc)
add_executable(manage manage.cc Results.cc ReportConsole.cc ReportExcel.cc ReportBase.cc Datasets.cc platformUtils.cc)
add_executable(list list.cc platformUtils Datasets.cc)
target_link_libraries(main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}")
if (${CMAKE_HOST_SYSTEM_NAME} MATCHES "Linux")
target_link_libraries(manage "${TORCH_LIBRARIES}" libxlsxwriter.so stdc++fs)
target_link_libraries(manage "${TORCH_LIBRARIES}" libxlsxwriter.so)
target_link_libraries(manage "${TORCH_LIBRARIES}" libxlsxwriter.so ArffFiles mdlp stdc++fs)
else()
target_link_libraries(manage "${TORCH_LIBRARIES}" libxlsxwriter.so ArffFiles mdlp)
endif()
target_link_libraries(list ArffFiles mdlp "${TORCH_LIBRARIES}")

View File

@ -1,5 +1,6 @@
#include <sstream>
#include <locale>
#include "Datasets.h"
#include "ReportExcel.h"
#include "BestResult.h"
@ -12,6 +13,15 @@ namespace platform {
string do_grouping() const { return "\03"; }
};
ReportExcel::ReportExcel(json data_) : ReportBase(data_), row(0)
{
normalSize = 14; //font size for report body
colorTitle = 0xB1A0C7;
colorOdd = 0xDCE6F1;
colorEven = 0xFDE9D9;
margin = .1; // margin to add to ZeroR comparison
createFile();
}
lxw_format* ReportExcel::efectiveStyle(const string& style)
{
@ -41,7 +51,7 @@ namespace platform {
void ReportExcel::formatColumns()
{
worksheet_freeze_panes(worksheet, 6, 1);
vector<int> columns_sizes = { 22, 10, 9, 7, 12, 12, 12, 12, 12, 3, 14, 12, 50 };
vector<int> columns_sizes = { 22, 10, 9, 7, 12, 12, 12, 12, 12, 3, 15, 12, 23 };
for (int i = 0; i < columns_sizes.size(); ++i) {
worksheet_set_column(worksheet, i, i, columns_sizes.at(i), NULL);
}
@ -65,9 +75,9 @@ namespace platform {
} else if (name == "bodyHeader") {
format_set_bold(style);
format_set_font_size(style, normalSize);
format_set_align(style, LXW_ALIGN_VERTICAL_CENTER);
format_set_align(style, LXW_ALIGN_CENTER);
format_set_align(style, LXW_ALIGN_VERTICAL_CENTER);
format_set_border(style, LXW_BORDER_THIN);
format_set_bg_color(style, lxw_color_t(colorTitle));
} else if (name == "result") {
format_set_font_size(style, normalSize);
@ -129,9 +139,17 @@ namespace platform {
format_set_align(headerSmall, LXW_ALIGN_VERTICAL_CENTER);
format_set_bg_color(headerSmall, lxw_color_t(colorOdd));
// Summary style
lxw_format* summaryStyle = workbook_add_format(workbook);
format_set_bold(summaryStyle);
format_set_font_size(summaryStyle, 16);
format_set_border(summaryStyle, LXW_BORDER_THIN);
format_set_align(summaryStyle, LXW_ALIGN_VERTICAL_CENTER);
styles["headerFirst"] = headerFirst;
styles["headerRest"] = headerRest;
styles["headerSmall"] = headerSmall;
styles["summaryStyle"] = summaryStyle;
}
void ReportExcel::setProperties()
@ -173,7 +191,7 @@ namespace platform {
locale::global(mylocale);
cout.imbue(mylocale);
stringstream oss;
string message = data["model"].get<string>() + " ver. " + data["version"].get<string>() +
string message = data["model"].get<string>() + " ver. " + data["version"].get<string>() + " " +
data["language"].get<string>() + " ver. " + data["language_version"].get<string>() +
" with " + to_string(data["folds"].get<int>()) + " Folds cross validation and " + to_string(data["seeds"].size()) +
" random seeds. " + data["date"].get<string>() + " " + data["time"].get<string>();
@ -211,6 +229,7 @@ namespace platform {
}
row = 6;
col = 0;
int hypSize = 22;
json lastResult;
double totalScore = 0.0;
string hyperparameters;
@ -224,7 +243,7 @@ namespace platform {
writeDouble(row, col + 6, r["depth"].get<double>(), "floats");
writeDouble(row, col + 7, r["score"].get<double>(), "result");
writeDouble(row, col + 8, r["score_std"].get<double>(), "result");
const string status = "X";
const string status = compareResult(r["dataset"].get<string>(), r["score"].get<double>());
writeString(row, col + 9, status, "textCentered");
writeDouble(row, col + 10, r["time"].get<double>(), "time");
writeDouble(row, col + 11, r["time_std"].get<double>(), "time");
@ -236,11 +255,18 @@ namespace platform {
oss << r["hyperparameters"];
hyperparameters = oss.str();
}
if (hyperparameters.size() > hypSize) {
hypSize = hyperparameters.size();
}
writeString(row, col + 12, hyperparameters, "text");
lastResult = r;
totalScore += r["score"].get<double>();
row++;
}
// Set the right column width of hyperparameters with the maximum length
worksheet_set_column(worksheet, 12, 12, hypSize + 1, NULL);
// Show totals if only one dataset is present in the result
if (data["results"].size() == 1) {
for (const string& group : { "scores_train", "scores_test", "times_train", "times_test" }) {
row++;
@ -254,9 +280,52 @@ namespace platform {
footer(totalScore, row);
}
}
string ReportExcel::compareResult(const string& dataset, double result)
{
string status = " ";
if (data["score_name"].get<string>() == "accuracy") {
auto dt = Datasets(Paths::datasets(), false);
dt.loadDataset(dataset);
auto numClasses = dt.getNClasses(dataset);
if (numClasses == 2) {
vector<int> distribution = dt.getClassesCounts(dataset);
vector<int>::iterator maxValue = max_element(distribution.begin(), distribution.end());
int maxCategory = distance(distribution.begin(), maxValue);
double mark = maxCategory * (1 + margin);
status = result < mark ? Symbols::cross : result > mark ? Symbols::upward_arrow : "=";
auto item = summary.find(status);
if (item != summary.end()) {
summary[status]++;
} else {
summary[status] = 1;
}
}
}
return status;
}
void ReportExcel::showSummary()
{
stringstream oss;
oss << "Better than ZeroR + " << setprecision(1) << fixed << margin * 100 << "%";
map<string, string> meaning = {
{Symbols::equal_best, "Equal to best"},
{Symbols::better_best, "Better than best"},
{Symbols::cross, "Less than or equal to ZeroR"},
{Symbols::upward_arrow, oss.str()}
};
for (const auto& item : summary) {
worksheet_write_string(worksheet, row + 2, 1, item.first.c_str(), styles["summaryStyle"]);
worksheet_write_number(worksheet, row + 2, 2, item.second, styles["summaryStyle"]);
worksheet_merge_range(worksheet, row + 2, 3, row + 2, 5, meaning.at(item.first).c_str(), styles["summaryStyle"]);
row += 1;
}
}
void ReportExcel::footer(double totalScore, int row)
{
showSummary();
row += 2 + summary.size();
auto score = data["score_name"].get<string>();
if (score == BestResult::scoreName()) {
worksheet_merge_range(worksheet, row + 2, 1, row + 2, 5, (score + " compared to " + BestResult::title() + " .:").c_str(), styles["text_even"]);

View File

@ -8,9 +8,20 @@
namespace platform {
using namespace std;
const int MAXLL = 128;
class Symbols {
public:
inline static const string check_mark{ "\u2714" };
inline static const string exclamation{ "\u2757" };
inline static const string black_star{ "\u2605" };
inline static const string cross{ "\u2717" };
inline static const string upward_arrow{ "\u27B6" };
inline static const string down_arrow{ "\u27B4" };
inline static const string equal_best{ check_mark };
inline static const string better_best{ black_star };
};
class ReportExcel : public ReportBase {
public:
explicit ReportExcel(json data_) : ReportBase(data_) { createFile(); };
explicit ReportExcel(json data_);
virtual ~ReportExcel() { closeFile(); };
private:
void writeString(int row, int col, const string& text, const string& style = "");
@ -21,14 +32,17 @@ namespace platform {
void setProperties();
void createFile();
void closeFile();
void showSummary();
lxw_workbook* workbook;
lxw_worksheet* worksheet;
map<string, lxw_format*> styles;
int row = 0;
int normalSize = 14; //font size for report body
uint32_t colorTitle = 0xB1A0C7;
uint32_t colorOdd = 0xDCE6F1;
uint32_t colorEven = 0xFDE9D9;
map<string, int> summary;
int row;
int normalSize; //font size for report body
uint32_t colorTitle;
uint32_t colorOdd;
uint32_t colorEven;
double margin;
const string fileName = "some_results.xlsx";
void header() override;
void body() override;
@ -36,6 +50,7 @@ namespace platform {
void createStyle(const string& name, lxw_format* style, bool odd);
void addColor(lxw_format* style, bool odd);
lxw_format* efectiveStyle(const string& name);
string compareResult(const string& dataset, double result);
};
};
#endif // !REPORTEXCEL_H

View File

@ -87,7 +87,7 @@ int main(int argc, char** argv)
auto stratified = program.get<bool>("stratified");
auto n_folds = program.get<int>("folds");
auto seeds = program.get<vector<int>>("seeds");
auto hyperparameters =program.get<string>("hyperparameters");
auto hyperparameters = program.get<string>("hyperparameters");
vector<string> filesToTest;
auto datasets = platform::Datasets(path, true, platform::ARFF);
auto title = program.get<string>("title");
@ -102,7 +102,7 @@ int main(int argc, char** argv)
}
filesToTest.push_back(file_name);
} else {
filesToTest = platform::Datasets(path, true, platform::ARFF).getNames();
filesToTest = datasets.getNames();
saveResults = true;
}
/*