Git add Confusion Matrix to console report

This commit is contained in:
2024-05-13 10:40:25 +02:00
parent 2b480cdcb7
commit 8d20545fd2
2 changed files with 17 additions and 1 deletions

View File

@@ -160,6 +160,16 @@ namespace platform {
recall_avg /= num_classes;
oss << classification_report_line("macro avg", precision_avg, recall_avg, f1_macro(), total);
oss << classification_report_line("weighted avg", precision_wavg, recall_wavg, f1_weighted(), total);
oss << std::endl << "Confusion Matrix" << std::endl;
oss << "================" << std::endl;
auto number = total > 1000 ? 4 : 3;
for (int i = 0; i < num_classes; i++) {
oss << std::right << std::setw(label_len) << labels[i] << " ";
for (int j = 0; j < num_classes; j++) {
oss << std::setw(number) << confusion_matrix[i][j].item<int>() << " ";
}
oss << std::endl;
}
return oss.str();
}
json Scores::get_confusion_matrix_json(bool labels_as_keys)