diff --git a/src/main/Scores.cpp b/src/main/Scores.cpp index ebf0914..435d9fb 100644 --- a/src/main/Scores.cpp +++ b/src/main/Scores.cpp @@ -40,6 +40,15 @@ namespace platform { } compute_accuracy_value(); } + static Score Scores::create_aggregate(json& data, std::string key) + { + auto scores = Scores(result[key][0]); + for (int i = 1; i < result[key].size(); i++) { + auto score = Scores(result[key][i]); + scores.aggregate(score); + } + return scores; + } void Scores::compute_accuracy_value() { accuracy_value = 0; diff --git a/src/main/Scores.h b/src/main/Scores.h index f54b35e..0d5a740 100644 --- a/src/main/Scores.h +++ b/src/main/Scores.h @@ -10,6 +10,7 @@ namespace platform { public: Scores(torch::Tensor& y_test, torch::Tensor& y_pred, int num_classes, std::vector labels = {}); explicit Scores(json& confusion_matrix_); + static Score create_aggregate(json& data, std::string key); float accuracy(); float f1_score(int num_class); float f1_weighted(); diff --git a/src/reports/ReportConsole.cpp b/src/reports/ReportConsole.cpp index 07fc8b4..e48fc54 100644 --- a/src/reports/ReportConsole.cpp +++ b/src/reports/ReportConsole.cpp @@ -186,13 +186,13 @@ namespace platform { int lines_header = 0; std::string color_line; std::string suffix = ""; - auto scores = aggregateScore(result, "confusion_matrices"); + auto scores = Scores::create_aggregate(result, "confusion_matrices"); auto output_test = scores.classification_report(color, "Test"); int maxLine = (*std::max_element(output_test.begin(), output_test.end(), [](const std::string& a, const std::string& b) { return a.size() < b.size(); })).size(); bool train_data = result.find("confusion_matrices_train") != result.end(); std::vector output_train; if (train_data) { - auto scores_train = aggregateScore(result, "confusion_matrices_train"); + auto scores_train = Scores::create_aggregate(result, "confusion_matrices_train"); output_train = scores_train.classification_report(color, "Train"); } oss << Colors::BLUE();