Add colors to confusion matrix and classification report

This commit is contained in:
2024-05-14 00:41:29 +02:00
parent 8d20545fd2
commit 99c9c6731f
5 changed files with 15 additions and 11 deletions

View File

@@ -1,5 +1,6 @@
#include <sstream>
#include "Scores.h"
#include "common/Colors.h"
namespace platform {
Scores::Scores(torch::Tensor& y_test, torch::Tensor& y_pred, int num_classes, std::vector<std::string> labels) : num_classes(num_classes), labels(labels)
{
@@ -128,14 +129,14 @@ namespace platform {
<< std::setw(dlen) << std::right << support << std::endl;
return oss.str();
}
std::string Scores::classification_report()
std::string Scores::classification_report(std::string color)
{
std::stringstream oss;
for (int i = 0; i < num_classes; i++) {
label_len = std::max(label_len, (int)labels[i].size());
}
oss << "Classification Report" << std::endl;
oss << "=====================" << std::endl;
oss << Colors::GREEN() << "Classification Report" << std::endl;
oss << "=====================" << std::endl << color;
oss << std::string(label_len, ' ') << " precision recall f1-score support" << std::endl;
oss << std::string(label_len, ' ') << " ========= ========= ========= =========" << std::endl;
for (int i = 0; i < num_classes; i++) {
@@ -160,16 +161,19 @@ namespace platform {
recall_avg /= num_classes;
oss << classification_report_line("macro avg", precision_avg, recall_avg, f1_macro(), total);
oss << classification_report_line("weighted avg", precision_wavg, recall_wavg, f1_weighted(), total);
oss << std::endl << "Confusion Matrix" << std::endl;
oss << "================" << std::endl;
oss << std::endl << Colors::GREEN() << "Confusion Matrix" << std::endl;
oss << "================" << std::endl << color;
auto number = total > 1000 ? 4 : 3;
for (int i = 0; i < num_classes; i++) {
oss << std::right << std::setw(label_len) << labels[i] << " ";
for (int j = 0; j < num_classes; j++) {
if (i == j) oss << Colors::GREEN();
oss << std::setw(number) << confusion_matrix[i][j].item<int>() << " ";
if (i == j) oss << color;
}
oss << std::endl;
}
oss << Colors::RESET();
return oss.str();
}
json Scores::get_confusion_matrix_json(bool labels_as_keys)