Compare commits

...

2 Commits

2 changed files with 3 additions and 2 deletions

View File

@@ -91,7 +91,7 @@ namespace platform {
} }
void Statistics::computeWTL() void Statistics::computeWTL()
{ {
// Compute the WTL matrix // Compute the WTL matrix (Win Tie Loss)
for (int i = 0; i < nModels; ++i) { for (int i = 0; i < nModels; ++i) {
wtl[i] = { 0, 0, 0 }; wtl[i] = { 0, 0, 0 };
} }

View File

@@ -4,7 +4,7 @@
#include <utility> #include <utility>
#include "RocAuc.h" #include "RocAuc.h"
namespace platform { namespace platform {
double RocAuc::compute(const torch::Tensor& y_proba, const torch::Tensor& labels) double RocAuc::compute(const torch::Tensor& y_proba, const torch::Tensor& labels)
{ {
size_t nClasses = y_proba.size(1); size_t nClasses = y_proba.size(1);
@@ -48,6 +48,7 @@ namespace platform {
double tp = 0, fp = 0; double tp = 0, fp = 0;
double totalPos = std::count(y_test.begin(), y_test.end(), classIdx); double totalPos = std::count(y_test.begin(), y_test.end(), classIdx);
double totalNeg = nSamples - totalPos; double totalNeg = nSamples - totalPos;
if (totalPos == 0 || totalNeg == 0) return 0.5; // neutral AUC
for (const auto& [score, label] : scoresAndLabels) { for (const auto& [score, label] : scoresAndLabels) {
if (label == 1) { if (label == 1) {