Refactor aggregate score to a constructor

This commit is contained in:
2024-05-17 22:52:13 +02:00
parent 577351eda5
commit 49a36904dc
3 changed files with 12 additions and 2 deletions

View File

@@ -40,6 +40,15 @@ namespace platform {
}
compute_accuracy_value();
}
static Score Scores::create_aggregate(json& data, std::string key)
{
auto scores = Scores(result[key][0]);
for (int i = 1; i < result[key].size(); i++) {
auto score = Scores(result[key][i]);
scores.aggregate(score);
}
return scores;
}
void Scores::compute_accuracy_value()
{
accuracy_value = 0;

View File

@@ -10,6 +10,7 @@ namespace platform {
public:
Scores(torch::Tensor& y_test, torch::Tensor& y_pred, int num_classes, std::vector<std::string> labels = {});
explicit Scores(json& confusion_matrix_);
static Score create_aggregate(json& data, std::string key);
float accuracy();
float f1_score(int num_class);
float f1_weighted();

View File

@@ -186,13 +186,13 @@ namespace platform {
int lines_header = 0;
std::string color_line;
std::string suffix = "";
auto scores = aggregateScore(result, "confusion_matrices");
auto scores = Scores::create_aggregate(result, "confusion_matrices");
auto output_test = scores.classification_report(color, "Test");
int maxLine = (*std::max_element(output_test.begin(), output_test.end(), [](const std::string& a, const std::string& b) { return a.size() < b.size(); })).size();
bool train_data = result.find("confusion_matrices_train") != result.end();
std::vector<std::string> output_train;
if (train_data) {
auto scores_train = aggregateScore(result, "confusion_matrices_train");
auto scores_train = Scores::create_aggregate(result, "confusion_matrices_train");
output_train = scores_train.classification_report(color, "Train");
}
oss << Colors::BLUE();