Add train classification report

This commit is contained in:
2024-05-14 11:45:54 +02:00
parent 99c9c6731f
commit 5c190d7c66
6 changed files with 85 additions and 28 deletions

View File

@@ -126,24 +126,28 @@ namespace platform {
oss << std::setw(dlen) << std::setprecision(ndec) << std::fixed << recall << " ";
}
oss << std::setw(dlen) << std::setprecision(ndec) << std::fixed << f1_score << " "
<< std::setw(dlen) << std::right << support << std::endl;
<< std::setw(dlen) << std::right << support;
return oss.str();
}
std::string Scores::classification_report(std::string color)
std::vector<std::string> Scores::classification_report(std::string color, std::string title)
{
std::stringstream oss;
std::vector<std::string> report;
for (int i = 0; i < num_classes; i++) {
label_len = std::max(label_len, (int)labels[i].size());
}
oss << Colors::GREEN() << "Classification Report" << std::endl;
oss << "=====================" << std::endl << color;
oss << std::string(label_len, ' ') << " precision recall f1-score support" << std::endl;
oss << std::string(label_len, ' ') << " ========= ========= ========= =========" << std::endl;
report.push_back("Classification Report using " + title + " dataset");
report.push_back("=========================================");
oss << std::string(label_len, ' ') << " precision recall f1-score support";
report.push_back(oss.str()); oss.str("");
oss << std::string(label_len, ' ') << " ========= ========= ========= =========";
report.push_back(oss.str()); oss.str("");
for (int i = 0; i < num_classes; i++) {
oss << classification_report_line(labels[i], precision(i), recall(i), f1_score(i), confusion_matrix[i].sum().item<int>());
report.push_back(classification_report_line(labels[i], precision(i), recall(i), f1_score(i), confusion_matrix[i].sum().item<int>()));
}
oss << std::endl;
report.push_back(" ");
oss << classification_report_line("accuracy", 0, 0, accuracy(), total);
report.push_back(oss.str()); oss.str("");
float precision_avg = 0;
float recall_avg = 0;
float precision_wavg = 0;
@@ -159,10 +163,11 @@ namespace platform {
recall_wavg /= total;
precision_avg /= num_classes;
recall_avg /= num_classes;
oss << classification_report_line("macro avg", precision_avg, recall_avg, f1_macro(), total);
oss << classification_report_line("weighted avg", precision_wavg, recall_wavg, f1_weighted(), total);
oss << std::endl << Colors::GREEN() << "Confusion Matrix" << std::endl;
oss << "================" << std::endl << color;
report.push_back(classification_report_line("macro avg", precision_avg, recall_avg, f1_macro(), total));
report.push_back(classification_report_line("weighted avg", precision_wavg, recall_wavg, f1_weighted(), total));
report.push_back("");
report.push_back("Confusion Matrix");
report.push_back("================");
auto number = total > 1000 ? 4 : 3;
for (int i = 0; i < num_classes; i++) {
oss << std::right << std::setw(label_len) << labels[i] << " ";
@@ -171,10 +176,9 @@ namespace platform {
oss << std::setw(number) << confusion_matrix[i][j].item<int>() << " ";
if (i == j) oss << color;
}
oss << std::endl;
report.push_back(oss.str()); oss.str("");
}
oss << Colors::RESET();
return oss.str();
return report;
}
json Scores::get_confusion_matrix_json(bool labels_as_keys)
{