RocAuc refactor to speed up binary classif. problems
This commit is contained in:
@@ -1,9 +1,7 @@
|
|||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <vector>
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include "common/Colors.h"
|
|
||||||
#include "RocAuc.h"
|
#include "RocAuc.h"
|
||||||
namespace platform {
|
namespace platform {
|
||||||
std::vector<int> tensorToVector(const torch::Tensor& tensor)
|
std::vector<int> tensorToVector(const torch::Tensor& tensor)
|
||||||
@@ -30,8 +28,10 @@ namespace platform {
|
|||||||
double RocAuc::compute(const torch::Tensor& y_proba, const torch::Tensor& labels)
|
double RocAuc::compute(const torch::Tensor& y_proba, const torch::Tensor& labels)
|
||||||
{
|
{
|
||||||
size_t nClasses = y_proba.size(1);
|
size_t nClasses = y_proba.size(1);
|
||||||
|
// In binary classification problem there's no need to calculate the average of the AUCs
|
||||||
|
if (nClasses == 2)
|
||||||
|
nClasses = 1;
|
||||||
size_t nSamples = y_proba.size(0);
|
size_t nSamples = y_proba.size(0);
|
||||||
assert(nSamples = y_test.size(0));
|
|
||||||
y_test = tensorToVector(labels);
|
y_test = tensorToVector(labels);
|
||||||
std::vector<double> aucScores(nClasses, 0.0);
|
std::vector<double> aucScores(nClasses, 0.0);
|
||||||
for (size_t classIdx = 0; classIdx < nClasses; ++classIdx) {
|
for (size_t classIdx = 0; classIdx < nClasses; ++classIdx) {
|
||||||
@@ -47,6 +47,9 @@ namespace platform {
|
|||||||
{
|
{
|
||||||
y_test = labels;
|
y_test = labels;
|
||||||
size_t nClasses = y_proba[0].size();
|
size_t nClasses = y_proba[0].size();
|
||||||
|
// In binary classification problem there's no need to calculate the average of the AUCs
|
||||||
|
if (nClasses == 2)
|
||||||
|
nClasses = 1;
|
||||||
size_t nSamples = y_proba.size();
|
size_t nSamples = y_proba.size();
|
||||||
std::vector<double> aucScores(nClasses, 0.0);
|
std::vector<double> aucScores(nClasses, 0.0);
|
||||||
for (size_t classIdx = 0; classIdx < nClasses; ++classIdx) {
|
for (size_t classIdx = 0; classIdx < nClasses; ++classIdx) {
|
||||||
|
Reference in New Issue
Block a user