#include <sstream>
#include <iostream>
#include <iomanip>
#include <algorithm>
#include <numeric>
#include "Scores.h"
#include "common/TensorUtils.hpp" // tensorToVector
#include "common/Colors.h"
namespace platform {
    Scores::Scores(torch::Tensor& y_test, torch::Tensor& y_proba, int num_classes, std::vector<std::string> labels)
        : num_classes(num_classes), labels(labels), y_test(y_test), y_proba(y_proba)
    {
        if (labels.size() == 0) {
            init_default_labels();
        }
        total = y_test.size(0);
        auto y_pred = y_proba.argmax(1);
        accuracy_value = (y_pred == y_test).sum().item<float>() / total;
        init_confusion_matrix();
        for (int i = 0; i < total; i++) {
            int actual = y_test[i].item<int>();
            int predicted = y_pred[i].item<int>();
            confusion_matrix[actual][predicted] += 1;
        }
    }
    Scores::Scores(const json& confusion_matrix_)
    {
        json values;
        total = 0;
        num_classes = confusion_matrix_.size();
        init_confusion_matrix();
        int i = 0;
        for (const auto& item : confusion_matrix_.items()) {
            values = item.value();
            json key = item.key();
            if (key.is_number_integer()) {
                labels.push_back("Class " + std::to_string(key.get<int>()));
            } else {
                labels.push_back(key.get<std::string>());
            }
            for (int j = 0; j < num_classes; ++j) {
                int value_int = values[j].get<int>();
                confusion_matrix[i][j] = value_int;
                total += value_int;
            }
            i++;
        }
        compute_accuracy_value();
    }
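    // One-vs-rest AUC: for each class, rank the samples by the predicted probability
    // of that class, sweep the ranking to trace the ROC curve, and integrate it with
    // the trapezoidal rule:
    //   AUC = sum_i 0.5 * (fpr[i] - fpr[i-1]) * (tpr[i] + tpr[i-1])
    // The returned value is the unweighted mean of the per-class AUCs; for binary
    // problems only the positive class is evaluated.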
    float Scores::auc()
    {
        size_t nSamples = y_test.numel();
        if (nSamples == 0) return 0;
        // In a binary classification problem there is no need to average the AUCs
        auto nClasses = num_classes;
        if (num_classes == 2)
            nClasses = 1;
        auto y_testv = TensorUtils::tensorToVector<int>(y_test);
        std::vector<double> aucScores(nClasses, 0.0);
        std::vector<std::pair<double, int>> scoresAndLabels;
        for (size_t classIdx = 0; classIdx < static_cast<size_t>(nClasses); ++classIdx) {
            if (classIdx >= static_cast<size_t>(y_proba.size(1))) {
                std::cerr << "AUC warning - class index out of range" << std::endl;
                return 0;
            }
            scoresAndLabels.clear();
            for (size_t i = 0; i < nSamples; ++i) {
                scoresAndLabels.emplace_back(y_proba[i][classIdx].item<double>(), y_testv[i] == static_cast<int>(classIdx) ? 1 : 0);
            }
            // Sort by descending score so thresholds are visited from strictest to loosest
            std::sort(scoresAndLabels.begin(), scoresAndLabels.end(), std::greater<>());
            std::vector<double> tpr, fpr;
            double tp = 0, fp = 0;
            double totalPos = std::count(y_testv.begin(), y_testv.end(), static_cast<int>(classIdx));
            double totalNeg = nSamples - totalPos;
            for (const auto& [score, label] : scoresAndLabels) {
                if (label == 1) {
                    tp += 1;
                } else {
                    fp += 1;
                }
                tpr.push_back(tp / totalPos);
                fpr.push_back(fp / totalNeg);
            }
            double auc = 0.0;
            for (size_t i = 1; i < tpr.size(); ++i) {
                auc += 0.5 * (fpr[i] - fpr[i - 1]) * (tpr[i] + tpr[i - 1]);
            }
            aucScores[classIdx] = auc;
        }
        return std::accumulate(aucScores.begin(), aucScores.end(), 0.0) / nClasses;
    }
    Scores Scores::create_aggregate(const json& data, const std::string key)
    {
        auto scores = Scores(data[key][0]);
        for (size_t i = 1; i < data[key].size(); i++) {
            auto score = Scores(data[key][i]);
            scores.aggregate(score);
        }
        return scores;
    }
    void Scores::compute_accuracy_value()
    {
        accuracy_value = 0;
        for (int i = 0; i < num_classes; i++) {
            accuracy_value += confusion_matrix[i][i].item<int>();
        }
        accuracy_value /= total;
        accuracy_value = std::min(accuracy_value, 1.0f);
    }
    void Scores::init_confusion_matrix()
    {
        confusion_matrix = torch::zeros({ num_classes, num_classes }, torch::kInt32);
    }
    void Scores::init_default_labels()
    {
        for (int i = 0; i < num_classes; i++) {
            labels.push_back("Class " + std::to_string(i));
        }
    }
    void Scores::aggregate(const Scores& a)
    {
        if (a.num_classes != num_classes)
            throw std::invalid_argument("The number of classes must be the same");
        confusion_matrix += a.confusion_matrix;
        total += a.total;
        compute_accuracy_value();
    }
    float Scores::accuracy()
    {
        return accuracy_value;
    }
    float Scores::f1_score(int num_class)
    {
        // Compute the F1 score in a one-vs-rest fashion
        auto precision_value = precision(num_class);
        auto recall_value = recall(num_class);
        if (precision_value + recall_value == 0) return 0; // Avoid division by zero (0/0 = 0)
        return 2 * precision_value * recall_value / (precision_value + recall_value);
    }
    float Scores::f1_weighted()
    {
        float f1_weighted = 0;
        for (int i = 0; i < num_classes; i++) {
            f1_weighted += confusion_matrix[i].sum().item<int>() * f1_score(i);
        }
        return f1_weighted / total;
    }
    float Scores::f1_macro()
    {
        float f1_macro = 0;
        for (int i = 0; i < num_classes; i++) {
            f1_macro += f1_score(i);
        }
        return f1_macro / num_classes;
    }
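    // Per-class precision and recall read directly off the confusion matrix, where
    // row i holds the true class and column j the predicted class:
    //   tp = M[c][c]               (predicted c, actually c)
    //   fp = column c sum - tp     (predicted c, actually another class)
    //   fn = row c sum - tp        (actually c, predicted another class)
    // so precision = tp / (tp + fp) and recall = tp / (tp + fn).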
    float Scores::precision(int num_class)
    {
        int tp = confusion_matrix[num_class][num_class].item<int>();
        int fp = confusion_matrix.index({ "...", num_class }).sum().item<int>() - tp;
        if (tp + fp == 0) return 0; // Avoid division by zero (0/0 = 0)
        return float(tp) / (tp + fp);
    }
    float Scores::recall(int num_class)
    {
        int tp = confusion_matrix[num_class][num_class].item<int>();
        int fn = confusion_matrix[num_class].sum().item<int>() - tp;
        if (tp + fn == 0) return 0; // Avoid division by zero (0/0 = 0)
        return float(tp) / (tp + fn);
    }
    std::string Scores::classification_report_line(std::string label, float precision, float recall, float f1_score, int support)
    {
        std::stringstream oss;
        oss << std::right << std::setw(label_len) << label << " ";
        if (precision == 0) {
            oss << std::string(dlen, ' ') << " ";
        } else {
            oss << std::setw(dlen) << std::setprecision(ndec) << std::fixed << precision << " ";
        }
        if (recall == 0) {
            oss << std::string(dlen, ' ') << " ";
        } else {
            oss << std::setw(dlen) << std::setprecision(ndec) << std::fixed << recall << " ";
        }
        oss << std::setw(dlen) << std::setprecision(ndec) << std::fixed << f1_score << " "
            << std::setw(dlen) << std::right << support;
        return oss.str();
    }
    std::tuple<float, float, float, float> Scores::compute_averages()
    {
        float precision_avg = 0;
        float recall_avg = 0;
        float precision_wavg = 0;
        float recall_wavg = 0;
        for (int i = 0; i < num_classes; i++) {
            int support = confusion_matrix[i].sum().item<int>();
            precision_avg += precision(i);
            precision_wavg += precision(i) * support;
            recall_avg += recall(i);
            recall_wavg += recall(i) * support;
        }
        precision_wavg /= total;
        recall_wavg /= total;
        precision_avg /= num_classes;
        recall_avg /= num_classes;
        return { precision_avg, recall_avg, precision_wavg, recall_wavg };
    }
    std::vector<std::string> Scores::classification_report(std::string color, std::string title)
    {
        std::stringstream oss;
        std::vector<std::string> report;
        for (int i = 0; i < num_classes; i++) {
            label_len = std::max(label_len, (int)labels[i].size());
        }
        report.push_back("Classification Report using " + title + " dataset");
        report.push_back("=========================================");
        oss << std::string(label_len, ' ') << " precision recall f1-score support";
        report.push_back(oss.str());
        oss.str("");
        oss << std::string(label_len, ' ') << " ========= ========= ========= =========";
        report.push_back(oss.str());
        oss.str("");
        for (int i = 0; i < num_classes; i++) {
            report.push_back(classification_report_line(labels[i], precision(i), recall(i), f1_score(i), confusion_matrix[i].sum().item<int>()));
        }
        report.push_back(" ");
        oss << classification_report_line("accuracy", 0, 0, accuracy(), total);
        report.push_back(oss.str());
        oss.str("");
        auto [precision_avg, recall_avg, precision_wavg, recall_wavg] = compute_averages();
        report.push_back(classification_report_line("macro avg", precision_avg, recall_avg, f1_macro(), total));
        report.push_back(classification_report_line("weighted avg", precision_wavg, recall_wavg, f1_weighted(), total));
        report.push_back("");
        report.push_back("Confusion Matrix");
        report.push_back("================");
        auto number = total > 1000 ? 4 : 3;
        for (int i = 0; i < num_classes; i++) {
            oss << std::right << std::setw(label_len) << labels[i] << " ";
            for (int j = 0; j < num_classes; j++) {
                if (i == j) oss << Colors::GREEN();
                oss << std::setw(number) << confusion_matrix[i][j].item<int>() << " ";
                if (i == j) oss << color;
            }
            report.push_back(oss.str());
            oss.str("");
        }
        return report;
    }
    json Scores::classification_report_json(std::string title)
    {
        json output;
        output["title"] = "Classification Report using " + title + " dataset";
        output["headers"] = { " ", "precision", "recall", "f1-score", "support" };
        output["body"] = json::array();
        for (int i = 0; i < num_classes; i++) {
            output["body"].push_back({ labels[i], precision(i), recall(i), f1_score(i), confusion_matrix[i].sum().item<int>() });
        }
        output["accuracy"] = { "accuracy", 0, 0, accuracy(), total };
        auto [precision_avg, recall_avg, precision_wavg, recall_wavg] = compute_averages();
        output["averages"] = { "macro avg", precision_avg, recall_avg, f1_macro(), total };
        output["weighted"] = { "weighted avg", precision_wavg, recall_wavg, f1_weighted(), total };
        output["confusion_matrix"] = get_confusion_matrix_json();
        return output;
    }
    json Scores::get_confusion_matrix_json(bool labels_as_keys)
    {
        json output;
        for (int i = 0; i < num_classes; i++) {
            auto r_ptr = confusion_matrix[i].data_ptr<int>();
            if (labels_as_keys) {
                output[labels[i]] = std::vector<int>(r_ptr, r_ptr + num_classes);
            } else {
                output[i] = std::vector<int>(r_ptr, r_ptr + num_classes);
            }
        }
        return output;
    }
}
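// A minimal usage sketch (hypothetical caller, not part of this translation unit).
// Tensor values are illustrative; "\033[0m" is a plain ANSI reset passed as the
// report color, since this file only shows that Colors::GREEN() exists:
//
//   torch::Tensor y_test = torch::tensor({0, 1, 1, 0}, torch::kInt32);
//   torch::Tensor y_proba = torch::tensor({{0.9, 0.1}, {0.2, 0.8}, {0.4, 0.6}, {0.7, 0.3}});
//   platform::Scores scores(y_test, y_proba, 2, {}); // empty labels -> "Class 0", "Class 1"
//   std::cout << "accuracy: " << scores.accuracy() << " auc: " << scores.auc() << std::endl;
//   for (const auto& line : scores.classification_report("\033[0m", "test"))
//       std::cout << line << std::endl;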