Add train classification report

This commit is contained in:
2024-05-14 11:45:54 +02:00
parent 99c9c6731f
commit 5c190d7c66
6 changed files with 85 additions and 28 deletions

View File

@@ -3,7 +3,6 @@
#include "best/BestScore.h"
#include "common/CLocale.h"
#include "ReportConsole.h"
#include "main/Scores.h"
namespace platform {
std::string ReportConsole::headerLine(const std::string& text, int utf = 0)
@@ -135,7 +134,7 @@ namespace platform {
}
sbody << std::string(MAXL, '*') << Colors::RESET() << std::endl;
vbody.push_back(std::string(MAXL, '*') + Colors::RESET() + "\n");
if (lastResult.find("confusion_matrices") != lastResult.end() && (data["results"].size() == 1 || selectedIndex != -1)) {
if (data["results"].size() == 1 || selectedIndex != -1) {
vbody.push_back(showClassificationReport(Colors::BLUE()));
}
}
@@ -169,17 +168,60 @@ namespace platform {
std::cout << headerLine("*** Best Results File not found. Couldn't compare any result!");
}
}
std::string ReportConsole::showClassificationReport(std::string color)
Scores ReportConsole::aggregateScore(std::string key)
{
auto lastResult = data["results"][0];
if (data["results"].size() > 1 || lastResult.find("confusion_matrices") == lastResult.end())
return "";
auto item = data["results"][0];
auto scores = Scores(item["confusion_matrices"][0]);
for (int i = 1; i < item["confusion_matrices"].size(); i++) {
auto score = Scores(item["confusion_matrices"][i]);
auto scores = Scores(item[key][0]);
for (int i = 1; i < item[key].size(); i++) {
auto score = Scores(item[key][i]);
scores.aggregate(score);
}
return scores.classification_report(color);
return scores;
}
std::string ReportConsole::showClassificationReport(std::string color)
{
std::stringstream oss;
auto result = data["results"][0];
if (result.find("confusion_matrices") == result.end())
return "";
auto scores = aggregateScore("confusion_matrices");
auto output_test = scores.classification_report(color, "Test");
oss << Colors::BLUE();
if (result.find("confusion_matrices_train") == result.end()) {
for (auto& line : output_test) {
oss << line << std::endl;
}
oss << Colors::RESET();
return oss.str();
}
auto scores_train = aggregateScore("confusion_matrices_train");
auto output_train = scores_train.classification_report(color, "Train");
int maxLine = (*std::max_element(output_train.begin(), output_train.end(), [](const std::string& a, const std::string& b) { return a.size() < b.size(); })).size();
bool second_header = false;
int lines_header = 0;
std::string color_line;
std::string suffix = "";
for (int i = 0; i < output_train.size(); i++) {
if (i < 2 || second_header) {
color_line = Colors::GREEN();
} else {
color_line = Colors::BLUE();
if (lines_header > 1)
suffix = std::string(14, ' '); // compensate for the color
}
oss << color_line << std::left << std::setw(maxLine) << output_train[i]
<< suffix << Colors::BLUE() << " | " << color_line << std::left << std::setw(maxLine)
<< output_test[i] << std::endl;
if (output_train[i] == "" || (second_header && lines_header < 2)) {
lines_header++;
second_header = true;
} else {
second_header = false;
}
}
oss << Colors::RESET();
return oss.str();
}
}

View File

@@ -4,6 +4,8 @@
#include "common/Colors.h"
#include <sstream>
#include "ReportBase.h"
#include "main/Scores.h"
namespace platform {
const int MAXL = 133;
@@ -24,6 +26,7 @@ namespace platform {
void do_body();
void footer(double totalScore);
void showSummary() override;
Scores aggregateScore(std::string key);
std::stringstream sheader;
std::stringstream sbody;
std::vector<std::string> vbody;