Begin classification report in excel

This commit is contained in:
2024-05-18 21:37:34 +02:00
parent b9e0c92334
commit 594adb0534
5 changed files with 113 additions and 37 deletions

View File

@@ -16,7 +16,7 @@ namespace platform {
confusion_matrix[actual][predicted] += 1;
}
}
Scores::Scores(json& confusion_matrix_)
Scores::Scores(const json& confusion_matrix_)
{
json values;
total = 0;
@@ -40,7 +40,7 @@ namespace platform {
}
compute_accuracy_value();
}
Scores Scores::create_aggregate(json& data, std::string key)
Scores Scores::create_aggregate(const json& data, const std::string key)
{
auto scores = Scores(data[key][0]);
for (int i = 1; i < data[key].size(); i++) {
@@ -138,6 +138,25 @@ namespace platform {
<< std::setw(dlen) << std::right << support;
return oss.str();
}
std::tuple<float, float, float, float> Scores::compute_averages()
{
float precision_avg = 0;
float recall_avg = 0;
float precision_wavg = 0;
float recall_wavg = 0;
for (int i = 0; i < num_classes; i++) {
int support = confusion_matrix[i].sum().item<int>();
precision_avg += precision(i);
precision_wavg += precision(i) * support;
recall_avg += recall(i);
recall_wavg += recall(i) * support;
}
precision_wavg /= total;
recall_wavg /= total;
precision_avg /= num_classes;
recall_avg /= num_classes;
return { precision_avg, recall_avg, precision_wavg, recall_wavg };
}
std::vector<std::string> Scores::classification_report(std::string color, std::string title)
{
std::stringstream oss;
@@ -157,21 +176,7 @@ namespace platform {
report.push_back(" ");
oss << classification_report_line("accuracy", 0, 0, accuracy(), total);
report.push_back(oss.str()); oss.str("");
float precision_avg = 0;
float recall_avg = 0;
float precision_wavg = 0;
float recall_wavg = 0;
for (int i = 0; i < num_classes; i++) {
int support = confusion_matrix[i].sum().item<int>();
precision_avg += precision(i);
precision_wavg += precision(i) * support;
recall_avg += recall(i);
recall_wavg += recall(i) * support;
}
precision_wavg /= total;
recall_wavg /= total;
precision_avg /= num_classes;
recall_avg /= num_classes;
auto [precision_avg, recall_avg, precision_wavg, recall_wavg] = compute_averages();
report.push_back(classification_report_line("macro avg", precision_avg, recall_avg, f1_macro(), total));
report.push_back(classification_report_line("weighted avg", precision_wavg, recall_wavg, f1_weighted(), total));
report.push_back("");
@@ -189,17 +194,33 @@ namespace platform {
}
return report;
}
json Scores::classification_report_json(std::string title)
{
json output;
output["title"] = "Classification Report using " + title + " dataset";
output["headers"] = { " ", "precision", "recall", "f1-score", "support" };
output["body"] = {};
for (int i = 0; i < num_classes; i++) {
output["body"].push_back({ labels[i], precision(i), recall(i), f1_score(i), confusion_matrix[i].sum().item<int>() });
}
output["accuracy"] = { "accuracy", 0, 0, accuracy(), total };
auto [precision_avg, recall_avg, precision_wavg, recall_wavg] = compute_averages();
output["averages"] = { "macro avg", precision_avg, recall_avg, f1_macro(), total };
output["weighted"] = { "weighted avg", precision_wavg, recall_wavg, f1_weighted(), total };
output["confusion_matrix"] = get_confusion_matrix_json();
return output;
}
json Scores::get_confusion_matrix_json(bool labels_as_keys)
{
json j;
json output;
for (int i = 0; i < num_classes; i++) {
auto r_ptr = confusion_matrix[i].data_ptr<int>();
if (labels_as_keys) {
j[labels[i]] = std::vector<int>(r_ptr, r_ptr + num_classes);
output[labels[i]] = std::vector<int>(r_ptr, r_ptr + num_classes);
} else {
j[i] = std::vector<int>(r_ptr, r_ptr + num_classes);
output[i] = std::vector<int>(r_ptr, r_ptr + num_classes);
}
}
return j;
return output;
}
}