diff --git a/src/main/Scores.cpp b/src/main/Scores.cpp index 435d9fb..6014504 100644 --- a/src/main/Scores.cpp +++ b/src/main/Scores.cpp @@ -40,11 +40,11 @@ namespace platform { } compute_accuracy_value(); } - static Score Scores::create_aggregate(json& data, std::string key) + Scores Scores::create_aggregate(json& data, std::string key) { - auto scores = Scores(result[key][0]); - for (int i = 1; i < result[key].size(); i++) { - auto score = Scores(result[key][i]); + auto scores = Scores(data[key][0]); + for (int i = 1; i < data[key].size(); i++) { + auto score = Scores(data[key][i]); scores.aggregate(score); } return scores; diff --git a/src/main/Scores.h b/src/main/Scores.h index 0d5a740..76c4ee4 100644 --- a/src/main/Scores.h +++ b/src/main/Scores.h @@ -10,7 +10,7 @@ namespace platform { public: Scores(torch::Tensor& y_test, torch::Tensor& y_pred, int num_classes, std::vector labels = {}); explicit Scores(json& confusion_matrix_); - static Score create_aggregate(json& data, std::string key); + static Scores create_aggregate(json& data, std::string key); float accuracy(); float f1_score(int num_class); float f1_weighted(); diff --git a/src/reports/ReportConsole.h b/src/reports/ReportConsole.h index c455848..73bcaec 100644 --- a/src/reports/ReportConsole.h +++ b/src/reports/ReportConsole.h @@ -6,7 +6,6 @@ #include "ReportBase.h" #include "main/Scores.h" - namespace platform { const int MAXL = 133; class ReportConsole : public ReportBase {