Add train classification report
This commit is contained in:
@@ -17,7 +17,7 @@ namespace platform {
|
||||
float precision(int num_class);
|
||||
float recall(int num_class);
|
||||
torch::Tensor get_confusion_matrix() { return confusion_matrix; }
|
||||
std::string classification_report(std::string color = "");
|
||||
std::vector<std::string> classification_report(std::string color = "", std::string title = "");
|
||||
json get_confusion_matrix_json(bool labels_as_keys = false);
|
||||
void aggregate(const Scores& a);
|
||||
private:
|
||||
@@ -30,7 +30,7 @@ namespace platform {
|
||||
int total;
|
||||
std::vector<std::string> labels;
|
||||
torch::Tensor confusion_matrix; // Rows ar actual, columns are predicted
|
||||
int label_len = 12;
|
||||
int label_len = 16;
|
||||
int dlen = 9;
|
||||
int ndec = 7;
|
||||
};
|
||||
|
Reference in New Issue
Block a user