Add train classification report

This commit is contained in:
2024-05-14 11:45:54 +02:00
parent 99c9c6731f
commit 5c190d7c66
6 changed files with 85 additions and 28 deletions

View File

@@ -17,7 +17,7 @@ namespace platform {
float precision(int num_class);
float recall(int num_class);
torch::Tensor get_confusion_matrix() { return confusion_matrix; }
std::string classification_report(std::string color = "");
std::vector<std::string> classification_report(std::string color = "", std::string title = "");
json get_confusion_matrix_json(bool labels_as_keys = false);
void aggregate(const Scores& a);
private:
@@ -30,7 +30,7 @@ namespace platform {
int total;
std::vector<std::string> labels;
torch::Tensor confusion_matrix; // Rows ar actual, columns are predicted
int label_len = 12;
int label_len = 16;
int dlen = 9;
int ndec = 7;
};