Compare commits

...

2 Commits

5 changed files with 174 additions and 36 deletions

View File

@@ -22,7 +22,7 @@ add_executable(
b_best commands/b_best.cpp best/Statistics.cpp b_best commands/b_best.cpp best/Statistics.cpp
best/BestResultsExcel.cpp best/BestResults.cpp best/BestResultsExcel.cpp best/BestResults.cpp
common/Datasets.cpp common/Dataset.cpp common/Datasets.cpp common/Dataset.cpp
main/Models.cpp main/Models.cpp main/Scores.cpp
reports/ReportExcel.cpp reports/ReportBase.cpp reports/ExcelFile.cpp reports/ReportExcel.cpp reports/ReportBase.cpp reports/ExcelFile.cpp
results/Result.cpp results/Result.cpp
) )

View File

@@ -16,7 +16,7 @@ namespace platform {
confusion_matrix[actual][predicted] += 1; confusion_matrix[actual][predicted] += 1;
} }
} }
Scores::Scores(json& confusion_matrix_) Scores::Scores(const json& confusion_matrix_)
{ {
json values; json values;
total = 0; total = 0;
@@ -40,7 +40,7 @@ namespace platform {
} }
compute_accuracy_value(); compute_accuracy_value();
} }
Scores Scores::create_aggregate(json& data, std::string key) Scores Scores::create_aggregate(const json& data, const std::string key)
{ {
auto scores = Scores(data[key][0]); auto scores = Scores(data[key][0]);
for (int i = 1; i < data[key].size(); i++) { for (int i = 1; i < data[key].size(); i++) {
@@ -138,6 +138,25 @@ namespace platform {
<< std::setw(dlen) << std::right << support; << std::setw(dlen) << std::right << support;
return oss.str(); return oss.str();
} }
std::tuple<float, float, float, float> Scores::compute_averages()
{
float precision_avg = 0;
float recall_avg = 0;
float precision_wavg = 0;
float recall_wavg = 0;
for (int i = 0; i < num_classes; i++) {
int support = confusion_matrix[i].sum().item<int>();
precision_avg += precision(i);
precision_wavg += precision(i) * support;
recall_avg += recall(i);
recall_wavg += recall(i) * support;
}
precision_wavg /= total;
recall_wavg /= total;
precision_avg /= num_classes;
recall_avg /= num_classes;
return { precision_avg, recall_avg, precision_wavg, recall_wavg };
}
std::vector<std::string> Scores::classification_report(std::string color, std::string title) std::vector<std::string> Scores::classification_report(std::string color, std::string title)
{ {
std::stringstream oss; std::stringstream oss;
@@ -157,21 +176,7 @@ namespace platform {
report.push_back(" "); report.push_back(" ");
oss << classification_report_line("accuracy", 0, 0, accuracy(), total); oss << classification_report_line("accuracy", 0, 0, accuracy(), total);
report.push_back(oss.str()); oss.str(""); report.push_back(oss.str()); oss.str("");
float precision_avg = 0; auto [precision_avg, recall_avg, precision_wavg, recall_wavg] = compute_averages();
float recall_avg = 0;
float precision_wavg = 0;
float recall_wavg = 0;
for (int i = 0; i < num_classes; i++) {
int support = confusion_matrix[i].sum().item<int>();
precision_avg += precision(i);
precision_wavg += precision(i) * support;
recall_avg += recall(i);
recall_wavg += recall(i) * support;
}
precision_wavg /= total;
recall_wavg /= total;
precision_avg /= num_classes;
recall_avg /= num_classes;
report.push_back(classification_report_line("macro avg", precision_avg, recall_avg, f1_macro(), total)); report.push_back(classification_report_line("macro avg", precision_avg, recall_avg, f1_macro(), total));
report.push_back(classification_report_line("weighted avg", precision_wavg, recall_wavg, f1_weighted(), total)); report.push_back(classification_report_line("weighted avg", precision_wavg, recall_wavg, f1_weighted(), total));
report.push_back(""); report.push_back("");
@@ -189,17 +194,33 @@ namespace platform {
} }
return report; return report;
} }
json Scores::classification_report_json(std::string title)
{
json output;
output["title"] = "Classification Report using " + title + " dataset";
output["headers"] = { " ", "precision", "recall", "f1-score", "support" };
output["body"] = {};
for (int i = 0; i < num_classes; i++) {
output["body"].push_back({ labels[i], precision(i), recall(i), f1_score(i), confusion_matrix[i].sum().item<int>() });
}
output["accuracy"] = { "accuracy", 0, 0, accuracy(), total };
auto [precision_avg, recall_avg, precision_wavg, recall_wavg] = compute_averages();
output["averages"] = { "macro avg", precision_avg, recall_avg, f1_macro(), total };
output["weighted"] = { "weighted avg", precision_wavg, recall_wavg, f1_weighted(), total };
output["confusion_matrix"] = get_confusion_matrix_json();
return output;
}
json Scores::get_confusion_matrix_json(bool labels_as_keys) json Scores::get_confusion_matrix_json(bool labels_as_keys)
{ {
json j; json output;
for (int i = 0; i < num_classes; i++) { for (int i = 0; i < num_classes; i++) {
auto r_ptr = confusion_matrix[i].data_ptr<int>(); auto r_ptr = confusion_matrix[i].data_ptr<int>();
if (labels_as_keys) { if (labels_as_keys) {
j[labels[i]] = std::vector<int>(r_ptr, r_ptr + num_classes); output[labels[i]] = std::vector<int>(r_ptr, r_ptr + num_classes);
} else { } else {
j[i] = std::vector<int>(r_ptr, r_ptr + num_classes); output[i] = std::vector<int>(r_ptr, r_ptr + num_classes);
} }
} }
return j; return output;
} }
} }

View File

@@ -4,15 +4,14 @@
#include <vector> #include <vector>
#include <string> #include <string>
#include <nlohmann/json.hpp> #include <nlohmann/json.hpp>
#include <xlsxwriter.h>
namespace platform { namespace platform {
using json = nlohmann::ordered_json; using json = nlohmann::ordered_json;
class Scores { class Scores {
public: public:
Scores(torch::Tensor& y_test, torch::Tensor& y_pred, int num_classes, std::vector<std::string> labels = {}); Scores(torch::Tensor& y_test, torch::Tensor& y_pred, int num_classes, std::vector<std::string> labels = {});
explicit Scores(json& confusion_matrix_); explicit Scores(const json& confusion_matrix_);
static Scores create_aggregate(json& data, std::string key); static Scores create_aggregate(const json& data, const std::string key);
float accuracy(); float accuracy();
float f1_score(int num_class); float f1_score(int num_class);
float f1_weighted(); float f1_weighted();
@@ -21,6 +20,7 @@ namespace platform {
float recall(int num_class); float recall(int num_class);
torch::Tensor get_confusion_matrix() { return confusion_matrix; } torch::Tensor get_confusion_matrix() { return confusion_matrix; }
std::vector<std::string> classification_report(std::string color = "", std::string title = ""); std::vector<std::string> classification_report(std::string color = "", std::string title = "");
json classification_report_json(std::string title = "");
json get_confusion_matrix_json(bool labels_as_keys = false); json get_confusion_matrix_json(bool labels_as_keys = false);
void aggregate(const Scores& a); void aggregate(const Scores& a);
private: private:
@@ -28,6 +28,7 @@ namespace platform {
void init_confusion_matrix(); void init_confusion_matrix();
void init_default_labels(); void init_default_labels();
void compute_accuracy_value(); void compute_accuracy_value();
std::tuple<float, float, float, float> compute_averages();
int num_classes; int num_classes;
float accuracy_value; float accuracy_value;
int total; int total;

View File

@@ -195,15 +195,7 @@ namespace platform {
} }
// Classificacion report // Classificacion report
if (lastResult.find("confusion_matrices") != lastResult.end()) { if (lastResult.find("confusion_matrices") != lastResult.end()) {
// auto score = platform2::Scores::create_aggregate(lastResult, "confusion_matrices"); create_classification_report(lastResult);
// row++;
// writeString(row, 1, "Classification Report", "bodyHeader");
// row++;
// auto output = platform2::Scores::classification_report("", "test");
// for (const auto& item : output) {
// writeString(row, 1, item, "text");
// row++;
// }
} }
// Set with of columns to show those totals completely // Set with of columns to show those totals completely
worksheet_set_column(worksheet, 1, 1, 12, NULL); worksheet_set_column(worksheet, 1, 1, 12, NULL);
@@ -215,7 +207,128 @@ namespace platform {
footer(totalScore, row); footer(totalScore, row);
} }
} }
void ReportExcel::create_classification_report(const json& result)
{
auto matrix_sheet = workbook_add_worksheet(workbook, "classif_report");
lxw_worksheet* tmp = worksheet;
worksheet = matrix_sheet;
if (matrix_sheet == NULL) {
throw std::invalid_argument("Couldn't create sheet classif_report");
}
worksheet_merge_range(matrix_sheet, 0, 0, 0, 5, "Classification Report", efectiveStyle("bodyHeader"));
int row = 2;
int col = 0;
if (result.find("confusion_matrices_train") != result.end()) {
// Train classification report
auto score = Scores::create_aggregate(result, "confusion_matrices_train");
auto train = score.classification_report_json("Train");
std::tie(row, col) = write_classification_report(train, row, 0);
int new_row = 0;
int new_col = col + 1;
for (int i = 0; i < result["confusion_matrices_train"].size(); ++i) {
auto item = result["confusion_matrices_train"][i];
auto score_item = Scores(item);
auto title = "Train Fold " + std::to_string(i);
std::tie(new_row, new_col) = write_classification_report(score_item.classification_report_json(title), 2, new_col);
new_col++;
}
}
// Test classification report
auto score = Scores::create_aggregate(result, "confusion_matrices");
auto test = score.classification_report_json("Test");
int init_row = ++row;
std::tie(row, col) = write_classification_report(test, init_row, 0);
int new_row = 0;
int new_col = col + 1;
for (int i = 0; i < result["confusion_matrices"].size(); ++i) {
auto item = result["confusion_matrices"][i];
auto score_item = Scores(item);
auto title = "Test Fold " + std::to_string(i);
std::tie(new_row, new_col) = write_classification_report(score_item.classification_report_json(title), init_row, new_col);
new_col++;
}
// Format columns (change size to fit the content)
for (int i = 0; i < new_col; ++i) {
// doesn't work with from col to col, so...
worksheet_set_column(worksheet, i, i, 12, NULL);
}
worksheet = tmp;
}
std::pair<int, int> ReportExcel::write_classification_report(const json& result, int init_row, int init_col)
{
int row = init_row;
auto text = result["title"].get<std::string>();
worksheet_merge_range(worksheet, row++, init_col, row, init_col + 5, text.c_str(), efectiveStyle("bodyHeader"));
int col = init_col + 2;
// Headers
bool first_item = true;
for (const auto& item : result["headers"]) {
auto text = item.get<std::string>();
if (first_item) {
first_item = false;
worksheet_merge_range(worksheet, row, init_col, row, init_col + 1, text.c_str(), efectiveStyle("bodyHeader"));
} else {
writeString(row, col++, text, "bodyHeader");
}
}
row++;
// Classes f1-score
for (const auto& item : result["body"]) {
col = init_col + 2;
for (const auto& value : item) {
if (value.is_string()) {
worksheet_merge_range(worksheet, row, init_col, row, init_col + 1, value.get<std::string>().c_str(), efectiveStyle("text"));
} else {
if (value.is_number_integer()) {
writeInt(row, col++, value.get<int>(), "ints");
} else {
writeDouble(row, col++, value.get<double>(), "result");
}
}
}
row++;
}
worksheet_merge_range(worksheet, row, init_col, row, init_col + 5, "", efectiveStyle("text"));
row++;
// Accuracy and average f1-score
for (const auto& item : { "accuracy", "averages", "weighted" }) {
col = init_col + 2;
for (const auto& value : result[item]) {
if (value.is_string()) {
worksheet_merge_range(worksheet, row, init_col, row, init_col + 1, value.get<std::string>().c_str(), efectiveStyle("text"));
} else {
if (value.is_number_integer()) {
writeInt(row, col++, value.get<int>(), "ints");
} else {
writeDouble(row, col++, value.get<double>(), "result");
}
}
}
row++;
}
// Confusion matrix
worksheet_merge_range(worksheet, row, init_col, row, init_col + 5, "", efectiveStyle("bodyHeader"));
row++;
auto n_items = result["confusion_matrix"].size();
worksheet_merge_range(worksheet, row, init_col, row, init_col + n_items + 1, "Confusion Matrix", efectiveStyle("bodyHeader"));
row++;
for (int i = 0; i < n_items; ++i) {
col = init_col + 2;
auto label = result["body"][i][0].get<std::string>();
worksheet_merge_range(worksheet, row, init_col, row, init_col + 1, label.c_str(), efectiveStyle("text"));
for (int j = 0; j < result["confusion_matrix"][i].size(); ++j) {
auto value = result["confusion_matrix"][i][j];
if (i == j) {
writeInt(row, col++, value.get<int>(), "ints_bold");
} else {
writeInt(row, col++, value.get<int>(), "ints");
}
}
row++;
}
int maxcol = std::max(5, int(init_col + n_items + 1));
return { row, maxcol };
}
void ReportExcel::showSummary() void ReportExcel::showSummary()
{ {
for (const auto& item : summary) { for (const auto& item : summary) {
@@ -225,7 +338,6 @@ namespace platform {
row += 1; row += 1;
} }
} }
void ReportExcel::footer(double totalScore, int row) void ReportExcel::footer(double totalScore, int row)
{ {
showSummary(); showSummary();

View File

@@ -1,5 +1,7 @@
#ifndef REPORT_EXCEL_H #ifndef REPORT_EXCEL_H
#define REPORT_EXCEL_H #define REPORT_EXCEL_H
#include <algorithm>
#include "main/Scores.h"
#include "common/Colors.h" #include "common/Colors.h"
#include "ReportBase.h" #include "ReportBase.h"
#include "ExcelFile.h" #include "ExcelFile.h"
@@ -19,6 +21,8 @@ namespace platform {
void showSummary() override; void showSummary() override;
void footer(double totalScore, int row); void footer(double totalScore, int row);
void append_notes(const json& r, int row); void append_notes(const json& r, int row);
void create_classification_report(const json& result);
std::pair<int, int> write_classification_report(const json& result, int init_row, int init_col);
void header_notes(int row); void header_notes(int row);
}; };
}; };