Files
Platform/src/main/RocAuc.h

21 lines
621 B
C++

#ifndef ROCAUC_H
#define ROCAUC_H
#include <torch/torch.h>
#include <vector>
#include <string>
#include <nlohmann/json.hpp>
namespace platform {
using json = nlohmann::ordered_json;
class RocAuc {
public:
RocAuc() = default;
double compute(const std::vector<std::vector<double>>& y_proba, const std::vector<int>& y_test);
double compute(const torch::Tensor& y_proba, const torch::Tensor& y_test);
private:
double compute_common(size_t nSamples, size_t classIdx);
std::vector<std::pair<double, int>> scoresAndLabels;
std::vector<int> y_test;
};
}
#endif