Add colors to confusion matrix and classification report
@@ -16,7 +16,7 @@ namespace platform {
         ReportConsole report(result.getJson());
         report.show();
         if (classification_report) {
-            std::cout << Colors::BLUE() << report.showClassificationReport() << Colors::RESET();
+            std::cout << report.showClassificationReport(Colors::BLUE());
         }
     }
     void Experiment::show()
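Note: the hunks in this commit rely on a small Colors helper from common/Colors.h (included in the next hunk); its definition is not part of the diff. Below is a minimal sketch of the shape the calls assume, with static members returning ANSI escape sequences as std::string; the exact escape codes are an assumption.

    // Sketch only: the real common/Colors.h is not shown in this commit.
    // It is assumed to expose static functions returning ANSI colour escapes
    // as std::string, so they can be streamed or concatenated with strings.
    #pragma once
    #include <string>

    class Colors {
    public:
        static std::string BLUE() { return "\033[34m"; }   // assumed code for blue
        static std::string GREEN() { return "\033[32m"; }  // assumed code for green
        static std::string RESET() { return "\033[0m"; }   // clears colour attributes
    };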
@@ -1,5 +1,6 @@
 #include <sstream>
 #include "Scores.h"
+#include "common/Colors.h"
 namespace platform {
     Scores::Scores(torch::Tensor& y_test, torch::Tensor& y_pred, int num_classes, std::vector<std::string> labels) : num_classes(num_classes), labels(labels)
     {
@@ -128,14 +129,14 @@ namespace platform {
             << std::setw(dlen) << std::right << support << std::endl;
         return oss.str();
     }
-    std::string Scores::classification_report()
+    std::string Scores::classification_report(std::string color)
     {
         std::stringstream oss;
         for (int i = 0; i < num_classes; i++) {
             label_len = std::max(label_len, (int)labels[i].size());
         }
-        oss << "Classification Report" << std::endl;
-        oss << "=====================" << std::endl;
+        oss << Colors::GREEN() << "Classification Report" << std::endl;
+        oss << "=====================" << std::endl << color;
         oss << std::string(label_len, ' ') << " precision recall f1-score support" << std::endl;
         oss << std::string(label_len, ' ') << " ========= ========= ========= =========" << std::endl;
         for (int i = 0; i < num_classes; i++) {
@@ -160,16 +161,19 @@ namespace platform {
         recall_avg /= num_classes;
         oss << classification_report_line("macro avg", precision_avg, recall_avg, f1_macro(), total);
         oss << classification_report_line("weighted avg", precision_wavg, recall_wavg, f1_weighted(), total);
-        oss << std::endl << "Confusion Matrix" << std::endl;
-        oss << "================" << std::endl;
+        oss << std::endl << Colors::GREEN() << "Confusion Matrix" << std::endl;
+        oss << "================" << std::endl << color;
         auto number = total > 1000 ? 4 : 3;
         for (int i = 0; i < num_classes; i++) {
             oss << std::right << std::setw(label_len) << labels[i] << " ";
             for (int j = 0; j < num_classes; j++) {
+                if (i == j) oss << Colors::GREEN();
                 oss << std::setw(number) << confusion_matrix[i][j].item<int>() << " ";
+                if (i == j) oss << color;
             }
             oss << std::endl;
         }
+        oss << Colors::RESET();
         return oss.str();
     }
     json Scores::get_confusion_matrix_json(bool labels_as_keys)
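The hunks above give Scores::classification_report an optional colour: section headings and the confusion-matrix diagonal are printed in Colors::GREEN(), the supplied colour is restored after each highlight, and the stream ends with Colors::RESET(). A hedged usage sketch follows (hypothetical caller; y_test, y_pred, num_classes and labels are assumed to exist):

    platform::Scores scores(y_test, y_pred, num_classes, labels);
    std::cout << scores.classification_report(Colors::BLUE());  // blue body, green headings and diagonal
    std::cout << scores.classification_report();                // default "" keeps pre-existing call sites compiling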
@@ -17,7 +17,7 @@ namespace platform {
         float precision(int num_class);
         float recall(int num_class);
         torch::Tensor get_confusion_matrix() { return confusion_matrix; }
-        std::string classification_report();
+        std::string classification_report(std::string color = "");
         json get_confusion_matrix_json(bool labels_as_keys = false);
         void aggregate(const Scores& a);
     private:
@@ -136,7 +136,7 @@ namespace platform {
         sbody << std::string(MAXL, '*') << Colors::RESET() << std::endl;
         vbody.push_back(std::string(MAXL, '*') + Colors::RESET() + "\n");
         if (lastResult.find("confusion_matrices") != lastResult.end() && (data["results"].size() == 1 || selectedIndex != -1)) {
-            vbody.push_back(Colors::BLUE() + showClassificationReport() + Colors::RESET());
+            vbody.push_back(showClassificationReport(Colors::BLUE()));
         }
     }
     void ReportConsole::showSummary()
@@ -169,7 +169,7 @@ namespace platform {
             std::cout << headerLine("*** Best Results File not found. Couldn't compare any result!");
         }
     }
-    std::string ReportConsole::showClassificationReport()
+    std::string ReportConsole::showClassificationReport(std::string color)
     {
         auto lastResult = data["results"][0];
         if (data["results"].size() > 1 || lastResult.find("confusion_matrices") == lastResult.end())
@@ -180,6 +180,6 @@ namespace platform {
             auto score = Scores(item["confusion_matrices"][i]);
             scores.aggregate(score);
         }
-        return scores.classification_report();
+        return scores.classification_report(color);
     }
 }
@@ -14,7 +14,7 @@ namespace platform {
         std::string fileReport();
         std::string getHeader() { do_header(); do_body(); return sheader.str(); }
         std::vector<std::string>& getBody() { return vbody; }
-        std::string showClassificationReport();
+        std::string showClassificationReport(std::string color);
     private:
         int selectedIndex;
         std::string headerLine(const std::string& text, int utf);
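Design note: Scores.h gives the colour parameter a default of "" so pre-existing calls to classification_report() keep compiling, whereas ReportConsole.h declares showClassificationReport(std::string color) without a default, so the call sites touched in this diff pass Colors::BLUE() explicitly and simply forward it down to Scores. A sketch of the resulting call chain, with caller code outside the diff assumed:

    // Hypothetical caller, not part of the commit:
    platform::ReportConsole report(result.getJson());
    std::cout << report.showClassificationReport(Colors::BLUE());
    // ReportConsole aggregates the stored confusion matrices into a Scores object
    // and returns scores.classification_report(Colors::BLUE()), which appends
    // Colors::RESET() at the end so later console output is unaffected.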