Complete reporconsole with classification report

This commit is contained in:
2024-05-14 13:22:13 +02:00
parent f8f3ca28dc
commit 30a6d5e60d
2 changed files with 18 additions and 12 deletions

View File

@@ -135,7 +135,7 @@ namespace platform {
sbody << std::string(MAXL, '*') << Colors::RESET() << std::endl; sbody << std::string(MAXL, '*') << Colors::RESET() << std::endl;
vbody.push_back(std::string(MAXL, '*') + Colors::RESET() + "\n"); vbody.push_back(std::string(MAXL, '*') + Colors::RESET() + "\n");
if (data["results"].size() == 1 || selectedIndex != -1) { if (data["results"].size() == 1 || selectedIndex != -1) {
vbody.push_back(showClassificationReport(Colors::BLUE())); vbody.push_back(buildClassificationReport(lastResult, Colors::BLUE()));
} }
} }
void ReportConsole::showSummary() void ReportConsole::showSummary()
@@ -168,34 +168,31 @@ namespace platform {
std::cout << headerLine("*** Best Results File not found. Couldn't compare any result!"); std::cout << headerLine("*** Best Results File not found. Couldn't compare any result!");
} }
} }
Scores ReportConsole::aggregateScore(std::string key) Scores ReportConsole::aggregateScore(json& result, std::string key)
{ {
auto lastResult = data["results"][0]; auto scores = Scores(result[key][0]);
auto item = data["results"][0]; for (int i = 1; i < result[key].size(); i++) {
auto scores = Scores(item[key][0]); auto score = Scores(result[key][i]);
for (int i = 1; i < item[key].size(); i++) {
auto score = Scores(item[key][i]);
scores.aggregate(score); scores.aggregate(score);
} }
return scores; return scores;
} }
std::string ReportConsole::showClassificationReport(std::string color) std::string ReportConsole::buildClassificationReport(json& result, std::string color)
{ {
std::stringstream oss; std::stringstream oss;
auto result = data["results"][0];
if (result.find("confusion_matrices") == result.end()) if (result.find("confusion_matrices") == result.end())
return ""; return "";
bool second_header = false; bool second_header = false;
int lines_header = 0; int lines_header = 0;
std::string color_line; std::string color_line;
std::string suffix = ""; std::string suffix = "";
auto scores = aggregateScore("confusion_matrices"); auto scores = aggregateScore(result, "confusion_matrices");
auto output_test = scores.classification_report(color, "Test"); auto output_test = scores.classification_report(color, "Test");
int maxLine = (*std::max_element(output_test.begin(), output_test.end(), [](const std::string& a, const std::string& b) { return a.size() < b.size(); })).size(); int maxLine = (*std::max_element(output_test.begin(), output_test.end(), [](const std::string& a, const std::string& b) { return a.size() < b.size(); })).size();
bool train_data = result.find("confusion_matrices_train") != result.end(); bool train_data = result.find("confusion_matrices_train") != result.end();
std::vector<std::string> output_train; std::vector<std::string> output_train;
if (train_data) { if (train_data) {
auto scores_train = aggregateScore("confusion_matrices_train"); auto scores_train = aggregateScore(result, "confusion_matrices_train");
output_train = scores_train.classification_report(color, "Train"); output_train = scores_train.classification_report(color, "Train");
} }
oss << Colors::BLUE(); oss << Colors::BLUE();
@@ -224,4 +221,12 @@ namespace platform {
oss << Colors::RESET(); oss << Colors::RESET();
return oss.str(); return oss.str();
} }
std::string ReportConsole::showClassificationReport(std::string color)
{
std::stringstream oss;
for (auto& result : data["results"]) {
oss << buildClassificationReport(result, color);
}
return oss.str();
}
} }

View File

@@ -20,13 +20,14 @@ namespace platform {
private: private:
int selectedIndex; int selectedIndex;
std::string headerLine(const std::string& text, int utf); std::string headerLine(const std::string& text, int utf);
std::string buildClassificationReport(json& result, std::string color);
void header() override; void header() override;
void do_header(); void do_header();
void body() override; void body() override;
void do_body(); void do_body();
void footer(double totalScore); void footer(double totalScore);
void showSummary() override; void showSummary() override;
Scores aggregateScore(std::string key); Scores aggregateScore(json& result, std::string key);
std::stringstream sheader; std::stringstream sheader;
std::stringstream sbody; std::stringstream sbody;
std::vector<std::string> vbody; std::vector<std::string> vbody;