diff --git a/src/main/RocAuc.cpp b/src/main/RocAuc.cpp index d396d76..7c301f8 100644 --- a/src/main/RocAuc.cpp +++ b/src/main/RocAuc.cpp @@ -1,9 +1,7 @@ #include -#include #include #include #include -#include "common/Colors.h" #include "RocAuc.h" namespace platform { std::vector tensorToVector(const torch::Tensor& tensor) @@ -30,8 +28,10 @@ namespace platform { double RocAuc::compute(const torch::Tensor& y_proba, const torch::Tensor& labels) { size_t nClasses = y_proba.size(1); + // In binary classification problem there's no need to calculate the average of the AUCs + if (nClasses == 2) + nClasses = 1; size_t nSamples = y_proba.size(0); - assert(nSamples = y_test.size(0)); y_test = tensorToVector(labels); std::vector aucScores(nClasses, 0.0); for (size_t classIdx = 0; classIdx < nClasses; ++classIdx) { @@ -47,6 +47,9 @@ namespace platform { { y_test = labels; size_t nClasses = y_proba[0].size(); + // In binary classification problem there's no need to calculate the average of the AUCs + if (nClasses == 2) + nClasses = 1; size_t nSamples = y_proba.size(); std::vector aucScores(nClasses, 0.0); for (size_t classIdx = 0; classIdx < nClasses; ++classIdx) {