Begin classification report in excel
This commit is contained in:
@@ -22,7 +22,7 @@ add_executable(
|
|||||||
b_best commands/b_best.cpp best/Statistics.cpp
|
b_best commands/b_best.cpp best/Statistics.cpp
|
||||||
best/BestResultsExcel.cpp best/BestResults.cpp
|
best/BestResultsExcel.cpp best/BestResults.cpp
|
||||||
common/Datasets.cpp common/Dataset.cpp
|
common/Datasets.cpp common/Dataset.cpp
|
||||||
main/Models.cpp
|
main/Models.cpp main/Scores.cpp
|
||||||
reports/ReportExcel.cpp reports/ReportBase.cpp reports/ExcelFile.cpp
|
reports/ReportExcel.cpp reports/ReportBase.cpp reports/ExcelFile.cpp
|
||||||
results/Result.cpp
|
results/Result.cpp
|
||||||
)
|
)
|
||||||
|
@@ -16,7 +16,7 @@ namespace platform {
|
|||||||
confusion_matrix[actual][predicted] += 1;
|
confusion_matrix[actual][predicted] += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Scores::Scores(json& confusion_matrix_)
|
Scores::Scores(const json& confusion_matrix_)
|
||||||
{
|
{
|
||||||
json values;
|
json values;
|
||||||
total = 0;
|
total = 0;
|
||||||
@@ -40,7 +40,7 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
compute_accuracy_value();
|
compute_accuracy_value();
|
||||||
}
|
}
|
||||||
Scores Scores::create_aggregate(json& data, std::string key)
|
Scores Scores::create_aggregate(const json& data, const std::string key)
|
||||||
{
|
{
|
||||||
auto scores = Scores(data[key][0]);
|
auto scores = Scores(data[key][0]);
|
||||||
for (int i = 1; i < data[key].size(); i++) {
|
for (int i = 1; i < data[key].size(); i++) {
|
||||||
@@ -138,6 +138,25 @@ namespace platform {
|
|||||||
<< std::setw(dlen) << std::right << support;
|
<< std::setw(dlen) << std::right << support;
|
||||||
return oss.str();
|
return oss.str();
|
||||||
}
|
}
|
||||||
|
std::tuple<float, float, float, float> Scores::compute_averages()
|
||||||
|
{
|
||||||
|
float precision_avg = 0;
|
||||||
|
float recall_avg = 0;
|
||||||
|
float precision_wavg = 0;
|
||||||
|
float recall_wavg = 0;
|
||||||
|
for (int i = 0; i < num_classes; i++) {
|
||||||
|
int support = confusion_matrix[i].sum().item<int>();
|
||||||
|
precision_avg += precision(i);
|
||||||
|
precision_wavg += precision(i) * support;
|
||||||
|
recall_avg += recall(i);
|
||||||
|
recall_wavg += recall(i) * support;
|
||||||
|
}
|
||||||
|
precision_wavg /= total;
|
||||||
|
recall_wavg /= total;
|
||||||
|
precision_avg /= num_classes;
|
||||||
|
recall_avg /= num_classes;
|
||||||
|
return { precision_avg, recall_avg, precision_wavg, recall_wavg };
|
||||||
|
}
|
||||||
std::vector<std::string> Scores::classification_report(std::string color, std::string title)
|
std::vector<std::string> Scores::classification_report(std::string color, std::string title)
|
||||||
{
|
{
|
||||||
std::stringstream oss;
|
std::stringstream oss;
|
||||||
@@ -157,21 +176,7 @@ namespace platform {
|
|||||||
report.push_back(" ");
|
report.push_back(" ");
|
||||||
oss << classification_report_line("accuracy", 0, 0, accuracy(), total);
|
oss << classification_report_line("accuracy", 0, 0, accuracy(), total);
|
||||||
report.push_back(oss.str()); oss.str("");
|
report.push_back(oss.str()); oss.str("");
|
||||||
float precision_avg = 0;
|
auto [precision_avg, recall_avg, precision_wavg, recall_wavg] = compute_averages();
|
||||||
float recall_avg = 0;
|
|
||||||
float precision_wavg = 0;
|
|
||||||
float recall_wavg = 0;
|
|
||||||
for (int i = 0; i < num_classes; i++) {
|
|
||||||
int support = confusion_matrix[i].sum().item<int>();
|
|
||||||
precision_avg += precision(i);
|
|
||||||
precision_wavg += precision(i) * support;
|
|
||||||
recall_avg += recall(i);
|
|
||||||
recall_wavg += recall(i) * support;
|
|
||||||
}
|
|
||||||
precision_wavg /= total;
|
|
||||||
recall_wavg /= total;
|
|
||||||
precision_avg /= num_classes;
|
|
||||||
recall_avg /= num_classes;
|
|
||||||
report.push_back(classification_report_line("macro avg", precision_avg, recall_avg, f1_macro(), total));
|
report.push_back(classification_report_line("macro avg", precision_avg, recall_avg, f1_macro(), total));
|
||||||
report.push_back(classification_report_line("weighted avg", precision_wavg, recall_wavg, f1_weighted(), total));
|
report.push_back(classification_report_line("weighted avg", precision_wavg, recall_wavg, f1_weighted(), total));
|
||||||
report.push_back("");
|
report.push_back("");
|
||||||
@@ -189,17 +194,33 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
return report;
|
return report;
|
||||||
}
|
}
|
||||||
|
json Scores::classification_report_json(std::string title)
|
||||||
|
{
|
||||||
|
json output;
|
||||||
|
output["title"] = "Classification Report using " + title + " dataset";
|
||||||
|
output["headers"] = { " ", "precision", "recall", "f1-score", "support" };
|
||||||
|
output["body"] = {};
|
||||||
|
for (int i = 0; i < num_classes; i++) {
|
||||||
|
output["body"].push_back({ labels[i], precision(i), recall(i), f1_score(i), confusion_matrix[i].sum().item<int>() });
|
||||||
|
}
|
||||||
|
output["accuracy"] = { "accuracy", 0, 0, accuracy(), total };
|
||||||
|
auto [precision_avg, recall_avg, precision_wavg, recall_wavg] = compute_averages();
|
||||||
|
output["averages"] = { "macro avg", precision_avg, recall_avg, f1_macro(), total };
|
||||||
|
output["weighted"] = { "weighted avg", precision_wavg, recall_wavg, f1_weighted(), total };
|
||||||
|
output["confusion_matrix"] = get_confusion_matrix_json();
|
||||||
|
return output;
|
||||||
|
}
|
||||||
json Scores::get_confusion_matrix_json(bool labels_as_keys)
|
json Scores::get_confusion_matrix_json(bool labels_as_keys)
|
||||||
{
|
{
|
||||||
json j;
|
json output;
|
||||||
for (int i = 0; i < num_classes; i++) {
|
for (int i = 0; i < num_classes; i++) {
|
||||||
auto r_ptr = confusion_matrix[i].data_ptr<int>();
|
auto r_ptr = confusion_matrix[i].data_ptr<int>();
|
||||||
if (labels_as_keys) {
|
if (labels_as_keys) {
|
||||||
j[labels[i]] = std::vector<int>(r_ptr, r_ptr + num_classes);
|
output[labels[i]] = std::vector<int>(r_ptr, r_ptr + num_classes);
|
||||||
} else {
|
} else {
|
||||||
j[i] = std::vector<int>(r_ptr, r_ptr + num_classes);
|
output[i] = std::vector<int>(r_ptr, r_ptr + num_classes);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return j;
|
return output;
|
||||||
}
|
}
|
||||||
}
|
}
|
@@ -4,15 +4,14 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <nlohmann/json.hpp>
|
#include <nlohmann/json.hpp>
|
||||||
#include <xlsxwriter.h>
|
|
||||||
|
|
||||||
namespace platform {
|
namespace platform {
|
||||||
using json = nlohmann::ordered_json;
|
using json = nlohmann::ordered_json;
|
||||||
class Scores {
|
class Scores {
|
||||||
public:
|
public:
|
||||||
Scores(torch::Tensor& y_test, torch::Tensor& y_pred, int num_classes, std::vector<std::string> labels = {});
|
Scores(torch::Tensor& y_test, torch::Tensor& y_pred, int num_classes, std::vector<std::string> labels = {});
|
||||||
explicit Scores(json& confusion_matrix_);
|
explicit Scores(const json& confusion_matrix_);
|
||||||
static Scores create_aggregate(json& data, std::string key);
|
static Scores create_aggregate(const json& data, const std::string key);
|
||||||
float accuracy();
|
float accuracy();
|
||||||
float f1_score(int num_class);
|
float f1_score(int num_class);
|
||||||
float f1_weighted();
|
float f1_weighted();
|
||||||
@@ -21,6 +20,7 @@ namespace platform {
|
|||||||
float recall(int num_class);
|
float recall(int num_class);
|
||||||
torch::Tensor get_confusion_matrix() { return confusion_matrix; }
|
torch::Tensor get_confusion_matrix() { return confusion_matrix; }
|
||||||
std::vector<std::string> classification_report(std::string color = "", std::string title = "");
|
std::vector<std::string> classification_report(std::string color = "", std::string title = "");
|
||||||
|
json classification_report_json(std::string title = "");
|
||||||
json get_confusion_matrix_json(bool labels_as_keys = false);
|
json get_confusion_matrix_json(bool labels_as_keys = false);
|
||||||
void aggregate(const Scores& a);
|
void aggregate(const Scores& a);
|
||||||
private:
|
private:
|
||||||
@@ -28,6 +28,7 @@ namespace platform {
|
|||||||
void init_confusion_matrix();
|
void init_confusion_matrix();
|
||||||
void init_default_labels();
|
void init_default_labels();
|
||||||
void compute_accuracy_value();
|
void compute_accuracy_value();
|
||||||
|
std::tuple<float, float, float, float> compute_averages();
|
||||||
int num_classes;
|
int num_classes;
|
||||||
float accuracy_value;
|
float accuracy_value;
|
||||||
int total;
|
int total;
|
||||||
|
@@ -195,27 +195,78 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
// Classificacion report
|
// Classificacion report
|
||||||
if (lastResult.find("confusion_matrices") != lastResult.end()) {
|
if (lastResult.find("confusion_matrices") != lastResult.end()) {
|
||||||
// auto score = platform2::Scores::create_aggregate(lastResult, "confusion_matrices");
|
create_classification_report(lastResult);
|
||||||
// row++;
|
|
||||||
// writeString(row, 1, "Classification Report", "bodyHeader");
|
|
||||||
// row++;
|
|
||||||
// auto output = platform2::Scores::classification_report("", "test");
|
|
||||||
// for (const auto& item : output) {
|
|
||||||
// writeString(row, 1, item, "text");
|
|
||||||
// row++;
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
// Set with of columns to show those totals completely
|
// Set with of columns to show those totals completely
|
||||||
worksheet_set_column(worksheet, 1, 1, 12, NULL);
|
for (int i = 0; i < 5; ++i) {
|
||||||
for (int i = 2; i < 7; ++i) {
|
|
||||||
// doesn't work with from col to col, so...
|
// doesn't work with from col to col, so...
|
||||||
worksheet_set_column(worksheet, i, i, 15, NULL);
|
worksheet_set_column(worksheet, i, i, 12, NULL);
|
||||||
}
|
}
|
||||||
|
worksheet_set_column(worksheet, 5, 5, 7, NULL);
|
||||||
} else {
|
} else {
|
||||||
footer(totalScore, row);
|
footer(totalScore, row);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
void ReportExcel::create_classification_report(const json& result)
|
||||||
|
{
|
||||||
|
auto matrix_sheet = workbook_add_worksheet(workbook, "classif_report");
|
||||||
|
lxw_worksheet* tmp = worksheet;
|
||||||
|
worksheet = matrix_sheet;
|
||||||
|
if (matrix_sheet == NULL) {
|
||||||
|
throw std::invalid_argument("Couldn't create sheet classif_report");
|
||||||
|
}
|
||||||
|
worksheet_merge_range(matrix_sheet, 0, 0, 0, 5, "Classification Report", efectiveStyle("bodyHeader"));
|
||||||
|
int row = 3;
|
||||||
|
if (result.find("confusion_matrices_train") != result.end()) {
|
||||||
|
auto score = Scores::create_aggregate(result, "confusion_matrices_train");
|
||||||
|
auto train = score.classification_report_json("Train");
|
||||||
|
row = write_classification_report(train, row);
|
||||||
|
}
|
||||||
|
auto score = Scores::create_aggregate(result, "confusion_matrices");
|
||||||
|
auto test = score.classification_report_json("Test");
|
||||||
|
write_classification_report(test, ++row);
|
||||||
|
for (int i = 1; i < 6; ++i) {
|
||||||
|
// doesn't work with from col to col, so...
|
||||||
|
worksheet_set_column(worksheet, i, i, 15, NULL);
|
||||||
|
}
|
||||||
|
worksheet = tmp;
|
||||||
|
}
|
||||||
|
int ReportExcel::write_classification_report(const json& result, int row)
|
||||||
|
{
|
||||||
|
auto text = result["title"].get<std::string>().c_str();
|
||||||
|
std::cout << "title: " << text << std::endl;
|
||||||
|
worksheet_merge_range(worksheet, row, 0, row, 5, text, efectiveStyle("bodyHeader"));
|
||||||
|
int col = 2;
|
||||||
|
row++;
|
||||||
|
bool first_item = true;
|
||||||
|
for (const auto& item : result["headers"]) {
|
||||||
|
auto text = item.get<std::string>().c_str();
|
||||||
|
if (first_item) {
|
||||||
|
first_item = false;
|
||||||
|
worksheet_merge_range(worksheet, row, 0, row, 1, text, efectiveStyle("bodyHeader"));
|
||||||
|
} else {
|
||||||
|
writeString(row, col++, text, "bodyHeader");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
row++;
|
||||||
|
for (const auto& item : result["body"]) {
|
||||||
|
col = 2;
|
||||||
|
for (const auto& value : item) {
|
||||||
|
if (value.is_string()) {
|
||||||
|
worksheet_merge_range(worksheet, row, 0, row, 1, value.get<std::string>().c_str(), efectiveStyle("text"));
|
||||||
|
} else {
|
||||||
|
if (value.is_number_integer()) {
|
||||||
|
writeInt(row, col++, value.get<int>(), "result");
|
||||||
|
} else {
|
||||||
|
writeDouble(row, col++, value.get<double>(), "result");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
row++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return row;
|
||||||
|
|
||||||
|
}
|
||||||
void ReportExcel::showSummary()
|
void ReportExcel::showSummary()
|
||||||
{
|
{
|
||||||
for (const auto& item : summary) {
|
for (const auto& item : summary) {
|
||||||
|
@@ -1,5 +1,6 @@
|
|||||||
#ifndef REPORT_EXCEL_H
|
#ifndef REPORT_EXCEL_H
|
||||||
#define REPORT_EXCEL_H
|
#define REPORT_EXCEL_H
|
||||||
|
#include "main/Scores.h"
|
||||||
#include "common/Colors.h"
|
#include "common/Colors.h"
|
||||||
#include "ReportBase.h"
|
#include "ReportBase.h"
|
||||||
#include "ExcelFile.h"
|
#include "ExcelFile.h"
|
||||||
@@ -19,6 +20,8 @@ namespace platform {
|
|||||||
void showSummary() override;
|
void showSummary() override;
|
||||||
void footer(double totalScore, int row);
|
void footer(double totalScore, int row);
|
||||||
void append_notes(const json& r, int row);
|
void append_notes(const json& r, int row);
|
||||||
|
void create_classification_report(const json& result);
|
||||||
|
int write_classification_report(const json& result, int row);
|
||||||
void header_notes(int row);
|
void header_notes(int row);
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
Reference in New Issue
Block a user