Compare commits
2 Commits
b639a2d79a
...
f5107abea7
Author | SHA1 | Date | |
---|---|---|---|
f5107abea7
|
|||
e64e281b63
|
@@ -91,7 +91,7 @@ namespace platform {
|
|||||||
}
|
}
|
||||||
void Statistics::computeWTL()
|
void Statistics::computeWTL()
|
||||||
{
|
{
|
||||||
// Compute the WTL matrix
|
// Compute the WTL matrix (Win Tie Loss)
|
||||||
for (int i = 0; i < nModels; ++i) {
|
for (int i = 0; i < nModels; ++i) {
|
||||||
wtl[i] = { 0, 0, 0 };
|
wtl[i] = { 0, 0, 0 };
|
||||||
}
|
}
|
||||||
|
@@ -4,7 +4,7 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
#include "RocAuc.h"
|
#include "RocAuc.h"
|
||||||
namespace platform {
|
namespace platform {
|
||||||
|
|
||||||
double RocAuc::compute(const torch::Tensor& y_proba, const torch::Tensor& labels)
|
double RocAuc::compute(const torch::Tensor& y_proba, const torch::Tensor& labels)
|
||||||
{
|
{
|
||||||
size_t nClasses = y_proba.size(1);
|
size_t nClasses = y_proba.size(1);
|
||||||
@@ -48,6 +48,7 @@ namespace platform {
|
|||||||
double tp = 0, fp = 0;
|
double tp = 0, fp = 0;
|
||||||
double totalPos = std::count(y_test.begin(), y_test.end(), classIdx);
|
double totalPos = std::count(y_test.begin(), y_test.end(), classIdx);
|
||||||
double totalNeg = nSamples - totalPos;
|
double totalNeg = nSamples - totalPos;
|
||||||
|
if (totalPos == 0 || totalNeg == 0) return 0.5; // neutral AUC
|
||||||
|
|
||||||
for (const auto& [score, label] : scoresAndLabels) {
|
for (const auto& [score, label] : scoresAndLabels) {
|
||||||
if (label == 1) {
|
if (label == 1) {
|
||||||
|
Reference in New Issue
Block a user