diff --git a/src/main/Experiment.cpp b/src/main/Experiment.cpp index 156cd99..07e095e 100644 --- a/src/main/Experiment.cpp +++ b/src/main/Experiment.cpp @@ -16,7 +16,7 @@ namespace platform { ReportConsole report(result.getJson()); report.show(); if (classification_report) { - std::cout << Colors::BLUE() << report.showClassificationReport() << Colors::RESET(); + std::cout << report.showClassificationReport(Colors::BLUE()); } } void Experiment::show() diff --git a/src/main/Scores.cpp b/src/main/Scores.cpp index 72095be..737af29 100644 --- a/src/main/Scores.cpp +++ b/src/main/Scores.cpp @@ -1,5 +1,6 @@ #include #include "Scores.h" +#include "common/Colors.h" namespace platform { Scores::Scores(torch::Tensor& y_test, torch::Tensor& y_pred, int num_classes, std::vector labels) : num_classes(num_classes), labels(labels) { @@ -128,14 +129,14 @@ namespace platform { << std::setw(dlen) << std::right << support << std::endl; return oss.str(); } - std::string Scores::classification_report() + std::string Scores::classification_report(std::string color) { std::stringstream oss; for (int i = 0; i < num_classes; i++) { label_len = std::max(label_len, (int)labels[i].size()); } - oss << "Classification Report" << std::endl; - oss << "=====================" << std::endl; + oss << Colors::GREEN() << "Classification Report" << std::endl; + oss << "=====================" << std::endl << color; oss << std::string(label_len, ' ') << " precision recall f1-score support" << std::endl; oss << std::string(label_len, ' ') << " ========= ========= ========= =========" << std::endl; for (int i = 0; i < num_classes; i++) { @@ -160,16 +161,19 @@ namespace platform { recall_avg /= num_classes; oss << classification_report_line("macro avg", precision_avg, recall_avg, f1_macro(), total); oss << classification_report_line("weighted avg", precision_wavg, recall_wavg, f1_weighted(), total); - oss << std::endl << "Confusion Matrix" << std::endl; - oss << "================" << std::endl; + oss << std::endl << Colors::GREEN() << "Confusion Matrix" << std::endl; + oss << "================" << std::endl << color; auto number = total > 1000 ? 4 : 3; for (int i = 0; i < num_classes; i++) { oss << std::right << std::setw(label_len) << labels[i] << " "; for (int j = 0; j < num_classes; j++) { + if (i == j) oss << Colors::GREEN(); oss << std::setw(number) << confusion_matrix[i][j].item() << " "; + if (i == j) oss << color; } oss << std::endl; } + oss << Colors::RESET(); return oss.str(); } json Scores::get_confusion_matrix_json(bool labels_as_keys) diff --git a/src/main/Scores.h b/src/main/Scores.h index 9de92a0..1f5bb8f 100644 --- a/src/main/Scores.h +++ b/src/main/Scores.h @@ -17,7 +17,7 @@ namespace platform { float precision(int num_class); float recall(int num_class); torch::Tensor get_confusion_matrix() { return confusion_matrix; } - std::string classification_report(); + std::string classification_report(std::string color = ""); json get_confusion_matrix_json(bool labels_as_keys = false); void aggregate(const Scores& a); private: diff --git a/src/reports/ReportConsole.cpp b/src/reports/ReportConsole.cpp index f029fc4..a44be15 100644 --- a/src/reports/ReportConsole.cpp +++ b/src/reports/ReportConsole.cpp @@ -136,7 +136,7 @@ namespace platform { sbody << std::string(MAXL, '*') << Colors::RESET() << std::endl; vbody.push_back(std::string(MAXL, '*') + Colors::RESET() + "\n"); if (lastResult.find("confusion_matrices") != lastResult.end() && (data["results"].size() == 1 || selectedIndex != -1)) { - vbody.push_back(Colors::BLUE() + showClassificationReport() + Colors::RESET()); + vbody.push_back(showClassificationReport(Colors::BLUE())); } } void ReportConsole::showSummary() @@ -169,7 +169,7 @@ namespace platform { std::cout << headerLine("*** Best Results File not found. Couldn't compare any result!"); } } - std::string ReportConsole::showClassificationReport() + std::string ReportConsole::showClassificationReport(std::string color) { auto lastResult = data["results"][0]; if (data["results"].size() > 1 || lastResult.find("confusion_matrices") == lastResult.end()) @@ -180,6 +180,6 @@ namespace platform { auto score = Scores(item["confusion_matrices"][i]); scores.aggregate(score); } - return scores.classification_report(); + return scores.classification_report(color); } } \ No newline at end of file diff --git a/src/reports/ReportConsole.h b/src/reports/ReportConsole.h index c85cc39..a332e31 100644 --- a/src/reports/ReportConsole.h +++ b/src/reports/ReportConsole.h @@ -14,7 +14,7 @@ namespace platform { std::string fileReport(); std::string getHeader() { do_header(); do_body(); return sheader.str(); } std::vector& getBody() { return vbody; } - std::string showClassificationReport(); + std::string showClassificationReport(std::string color); private: int selectedIndex; std::string headerLine(const std::string& text, int utf);