From c165a4bdda57ee9b73daf56074dbad2f61469b32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Fri, 17 May 2024 23:38:21 +0200 Subject: [PATCH] Fix refactor of static aggregate method --- src/main/Scores.cpp | 8 ++++---- src/main/Scores.h | 2 +- src/reports/ReportConsole.h | 1 - 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/main/Scores.cpp b/src/main/Scores.cpp index 435d9fb..6014504 100644 --- a/src/main/Scores.cpp +++ b/src/main/Scores.cpp @@ -40,11 +40,11 @@ namespace platform { } compute_accuracy_value(); } - static Score Scores::create_aggregate(json& data, std::string key) + Scores Scores::create_aggregate(json& data, std::string key) { - auto scores = Scores(result[key][0]); - for (int i = 1; i < result[key].size(); i++) { - auto score = Scores(result[key][i]); + auto scores = Scores(data[key][0]); + for (int i = 1; i < data[key].size(); i++) { + auto score = Scores(data[key][i]); scores.aggregate(score); } return scores; diff --git a/src/main/Scores.h b/src/main/Scores.h index 0d5a740..76c4ee4 100644 --- a/src/main/Scores.h +++ b/src/main/Scores.h @@ -10,7 +10,7 @@ namespace platform { public: Scores(torch::Tensor& y_test, torch::Tensor& y_pred, int num_classes, std::vector labels = {}); explicit Scores(json& confusion_matrix_); - static Score create_aggregate(json& data, std::string key); + static Scores create_aggregate(json& data, std::string key); float accuracy(); float f1_score(int num_class); float f1_weighted(); diff --git a/src/reports/ReportConsole.h b/src/reports/ReportConsole.h index c455848..73bcaec 100644 --- a/src/reports/ReportConsole.h +++ b/src/reports/ReportConsole.h @@ -6,7 +6,6 @@ #include "ReportBase.h" #include "main/Scores.h" - namespace platform { const int MAXL = 133; class ReportConsole : public ReportBase {