Complete comparison with ZeroR
This commit is contained in:
parent
1bdfbd1620
commit
f69f415b92
@ -5,12 +5,12 @@ include_directories(${BayesNet_SOURCE_DIR}/lib/mdlp)
|
||||
include_directories(${BayesNet_SOURCE_DIR}/lib/argparse/include)
|
||||
include_directories(${BayesNet_SOURCE_DIR}/lib/json/include)
|
||||
add_executable(main main.cc Folding.cc platformUtils.cc Experiment.cc Datasets.cc Models.cc ReportConsole.cc ReportBase.cc)
|
||||
add_executable(manage manage.cc Results.cc ReportConsole.cc ReportExcel.cc ReportBase.cc)
|
||||
add_executable(manage manage.cc Results.cc ReportConsole.cc ReportExcel.cc ReportBase.cc Datasets.cc platformUtils.cc)
|
||||
add_executable(list list.cc platformUtils Datasets.cc)
|
||||
target_link_libraries(main BayesNet ArffFiles mdlp "${TORCH_LIBRARIES}")
|
||||
if (${CMAKE_HOST_SYSTEM_NAME} MATCHES "Linux")
|
||||
target_link_libraries(manage "${TORCH_LIBRARIES}" libxlsxwriter.so stdc++fs)
|
||||
target_link_libraries(manage "${TORCH_LIBRARIES}" libxlsxwriter.so)
|
||||
target_link_libraries(manage "${TORCH_LIBRARIES}" libxlsxwriter.so ArffFiles mdlp stdc++fs)
|
||||
else()
|
||||
target_link_libraries(manage "${TORCH_LIBRARIES}" libxlsxwriter.so ArffFiles mdlp)
|
||||
endif()
|
||||
target_link_libraries(list ArffFiles mdlp "${TORCH_LIBRARIES}")
|
@ -1,5 +1,6 @@
|
||||
#include <sstream>
|
||||
#include <locale>
|
||||
#include "Datasets.h"
|
||||
#include "ReportExcel.h"
|
||||
#include "BestResult.h"
|
||||
|
||||
@ -12,6 +13,15 @@ namespace platform {
|
||||
|
||||
string do_grouping() const { return "\03"; }
|
||||
};
|
||||
ReportExcel::ReportExcel(json data_) : ReportBase(data_), row(0)
|
||||
{
|
||||
normalSize = 14; //font size for report body
|
||||
colorTitle = 0xB1A0C7;
|
||||
colorOdd = 0xDCE6F1;
|
||||
colorEven = 0xFDE9D9;
|
||||
margin = .1; // margin to add to ZeroR comparison
|
||||
createFile();
|
||||
}
|
||||
|
||||
lxw_format* ReportExcel::efectiveStyle(const string& style)
|
||||
{
|
||||
@ -41,7 +51,7 @@ namespace platform {
|
||||
void ReportExcel::formatColumns()
|
||||
{
|
||||
worksheet_freeze_panes(worksheet, 6, 1);
|
||||
vector<int> columns_sizes = { 22, 10, 9, 7, 12, 12, 12, 12, 12, 3, 14, 12, 50 };
|
||||
vector<int> columns_sizes = { 22, 10, 9, 7, 12, 12, 12, 12, 12, 3, 15, 12, 23 };
|
||||
for (int i = 0; i < columns_sizes.size(); ++i) {
|
||||
worksheet_set_column(worksheet, i, i, columns_sizes.at(i), NULL);
|
||||
}
|
||||
@ -65,9 +75,9 @@ namespace platform {
|
||||
} else if (name == "bodyHeader") {
|
||||
format_set_bold(style);
|
||||
format_set_font_size(style, normalSize);
|
||||
format_set_align(style, LXW_ALIGN_VERTICAL_CENTER);
|
||||
format_set_align(style, LXW_ALIGN_CENTER);
|
||||
format_set_align(style, LXW_ALIGN_VERTICAL_CENTER);
|
||||
format_set_border(style, LXW_BORDER_THIN);
|
||||
format_set_bg_color(style, lxw_color_t(colorTitle));
|
||||
} else if (name == "result") {
|
||||
format_set_font_size(style, normalSize);
|
||||
@ -129,9 +139,17 @@ namespace platform {
|
||||
format_set_align(headerSmall, LXW_ALIGN_VERTICAL_CENTER);
|
||||
format_set_bg_color(headerSmall, lxw_color_t(colorOdd));
|
||||
|
||||
// Summary style
|
||||
lxw_format* summaryStyle = workbook_add_format(workbook);
|
||||
format_set_bold(summaryStyle);
|
||||
format_set_font_size(summaryStyle, 16);
|
||||
format_set_border(summaryStyle, LXW_BORDER_THIN);
|
||||
format_set_align(summaryStyle, LXW_ALIGN_VERTICAL_CENTER);
|
||||
|
||||
styles["headerFirst"] = headerFirst;
|
||||
styles["headerRest"] = headerRest;
|
||||
styles["headerSmall"] = headerSmall;
|
||||
styles["summaryStyle"] = summaryStyle;
|
||||
}
|
||||
|
||||
void ReportExcel::setProperties()
|
||||
@ -173,7 +191,7 @@ namespace platform {
|
||||
locale::global(mylocale);
|
||||
cout.imbue(mylocale);
|
||||
stringstream oss;
|
||||
string message = data["model"].get<string>() + " ver. " + data["version"].get<string>() +
|
||||
string message = data["model"].get<string>() + " ver. " + data["version"].get<string>() + " " +
|
||||
data["language"].get<string>() + " ver. " + data["language_version"].get<string>() +
|
||||
" with " + to_string(data["folds"].get<int>()) + " Folds cross validation and " + to_string(data["seeds"].size()) +
|
||||
" random seeds. " + data["date"].get<string>() + " " + data["time"].get<string>();
|
||||
@ -211,6 +229,7 @@ namespace platform {
|
||||
}
|
||||
row = 6;
|
||||
col = 0;
|
||||
int hypSize = 22;
|
||||
json lastResult;
|
||||
double totalScore = 0.0;
|
||||
string hyperparameters;
|
||||
@ -224,7 +243,7 @@ namespace platform {
|
||||
writeDouble(row, col + 6, r["depth"].get<double>(), "floats");
|
||||
writeDouble(row, col + 7, r["score"].get<double>(), "result");
|
||||
writeDouble(row, col + 8, r["score_std"].get<double>(), "result");
|
||||
const string status = "X";
|
||||
const string status = compareResult(r["dataset"].get<string>(), r["score"].get<double>());
|
||||
writeString(row, col + 9, status, "textCentered");
|
||||
writeDouble(row, col + 10, r["time"].get<double>(), "time");
|
||||
writeDouble(row, col + 11, r["time_std"].get<double>(), "time");
|
||||
@ -236,11 +255,18 @@ namespace platform {
|
||||
oss << r["hyperparameters"];
|
||||
hyperparameters = oss.str();
|
||||
}
|
||||
if (hyperparameters.size() > hypSize) {
|
||||
hypSize = hyperparameters.size();
|
||||
}
|
||||
writeString(row, col + 12, hyperparameters, "text");
|
||||
lastResult = r;
|
||||
totalScore += r["score"].get<double>();
|
||||
row++;
|
||||
|
||||
}
|
||||
// Set the right column width of hyperparameters with the maximum length
|
||||
worksheet_set_column(worksheet, 12, 12, hypSize + 1, NULL);
|
||||
// Show totals if only one dataset is present in the result
|
||||
if (data["results"].size() == 1) {
|
||||
for (const string& group : { "scores_train", "scores_test", "times_train", "times_test" }) {
|
||||
row++;
|
||||
@ -254,9 +280,52 @@ namespace platform {
|
||||
footer(totalScore, row);
|
||||
}
|
||||
}
|
||||
string ReportExcel::compareResult(const string& dataset, double result)
|
||||
{
|
||||
string status = " ";
|
||||
if (data["score_name"].get<string>() == "accuracy") {
|
||||
auto dt = Datasets(Paths::datasets(), false);
|
||||
dt.loadDataset(dataset);
|
||||
auto numClasses = dt.getNClasses(dataset);
|
||||
if (numClasses == 2) {
|
||||
vector<int> distribution = dt.getClassesCounts(dataset);
|
||||
vector<int>::iterator maxValue = max_element(distribution.begin(), distribution.end());
|
||||
int maxCategory = distance(distribution.begin(), maxValue);
|
||||
double mark = maxCategory * (1 + margin);
|
||||
status = result < mark ? Symbols::cross : result > mark ? Symbols::upward_arrow : "=";
|
||||
auto item = summary.find(status);
|
||||
if (item != summary.end()) {
|
||||
summary[status]++;
|
||||
} else {
|
||||
summary[status] = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
return status;
|
||||
}
|
||||
void ReportExcel::showSummary()
|
||||
{
|
||||
stringstream oss;
|
||||
oss << "Better than ZeroR + " << setprecision(1) << fixed << margin * 100 << "%";
|
||||
|
||||
map<string, string> meaning = {
|
||||
{Symbols::equal_best, "Equal to best"},
|
||||
{Symbols::better_best, "Better than best"},
|
||||
{Symbols::cross, "Less than or equal to ZeroR"},
|
||||
{Symbols::upward_arrow, oss.str()}
|
||||
};
|
||||
for (const auto& item : summary) {
|
||||
worksheet_write_string(worksheet, row + 2, 1, item.first.c_str(), styles["summaryStyle"]);
|
||||
worksheet_write_number(worksheet, row + 2, 2, item.second, styles["summaryStyle"]);
|
||||
worksheet_merge_range(worksheet, row + 2, 3, row + 2, 5, meaning.at(item.first).c_str(), styles["summaryStyle"]);
|
||||
row += 1;
|
||||
}
|
||||
}
|
||||
|
||||
void ReportExcel::footer(double totalScore, int row)
|
||||
{
|
||||
showSummary();
|
||||
row += 2 + summary.size();
|
||||
auto score = data["score_name"].get<string>();
|
||||
if (score == BestResult::scoreName()) {
|
||||
worksheet_merge_range(worksheet, row + 2, 1, row + 2, 5, (score + " compared to " + BestResult::title() + " .:").c_str(), styles["text_even"]);
|
||||
|
@ -8,9 +8,20 @@
|
||||
namespace platform {
|
||||
using namespace std;
|
||||
const int MAXLL = 128;
|
||||
class Symbols {
|
||||
public:
|
||||
inline static const string check_mark{ "\u2714" };
|
||||
inline static const string exclamation{ "\u2757" };
|
||||
inline static const string black_star{ "\u2605" };
|
||||
inline static const string cross{ "\u2717" };
|
||||
inline static const string upward_arrow{ "\u27B6" };
|
||||
inline static const string down_arrow{ "\u27B4" };
|
||||
inline static const string equal_best{ check_mark };
|
||||
inline static const string better_best{ black_star };
|
||||
};
|
||||
class ReportExcel : public ReportBase {
|
||||
public:
|
||||
explicit ReportExcel(json data_) : ReportBase(data_) { createFile(); };
|
||||
explicit ReportExcel(json data_);
|
||||
virtual ~ReportExcel() { closeFile(); };
|
||||
private:
|
||||
void writeString(int row, int col, const string& text, const string& style = "");
|
||||
@ -21,14 +32,17 @@ namespace platform {
|
||||
void setProperties();
|
||||
void createFile();
|
||||
void closeFile();
|
||||
void showSummary();
|
||||
lxw_workbook* workbook;
|
||||
lxw_worksheet* worksheet;
|
||||
map<string, lxw_format*> styles;
|
||||
int row = 0;
|
||||
int normalSize = 14; //font size for report body
|
||||
uint32_t colorTitle = 0xB1A0C7;
|
||||
uint32_t colorOdd = 0xDCE6F1;
|
||||
uint32_t colorEven = 0xFDE9D9;
|
||||
map<string, int> summary;
|
||||
int row;
|
||||
int normalSize; //font size for report body
|
||||
uint32_t colorTitle;
|
||||
uint32_t colorOdd;
|
||||
uint32_t colorEven;
|
||||
double margin;
|
||||
const string fileName = "some_results.xlsx";
|
||||
void header() override;
|
||||
void body() override;
|
||||
@ -36,6 +50,7 @@ namespace platform {
|
||||
void createStyle(const string& name, lxw_format* style, bool odd);
|
||||
void addColor(lxw_format* style, bool odd);
|
||||
lxw_format* efectiveStyle(const string& name);
|
||||
string compareResult(const string& dataset, double result);
|
||||
};
|
||||
};
|
||||
#endif // !REPORTEXCEL_H
|
@ -87,7 +87,7 @@ int main(int argc, char** argv)
|
||||
auto stratified = program.get<bool>("stratified");
|
||||
auto n_folds = program.get<int>("folds");
|
||||
auto seeds = program.get<vector<int>>("seeds");
|
||||
auto hyperparameters =program.get<string>("hyperparameters");
|
||||
auto hyperparameters = program.get<string>("hyperparameters");
|
||||
vector<string> filesToTest;
|
||||
auto datasets = platform::Datasets(path, true, platform::ARFF);
|
||||
auto title = program.get<string>("title");
|
||||
@ -102,7 +102,7 @@ int main(int argc, char** argv)
|
||||
}
|
||||
filesToTest.push_back(file_name);
|
||||
} else {
|
||||
filesToTest = platform::Datasets(path, true, platform::ARFF).getNames();
|
||||
filesToTest = datasets.getNames();
|
||||
saveResults = true;
|
||||
}
|
||||
/*
|
||||
|
Loading…
Reference in New Issue
Block a user