Add roc-auc-ovr as score to b_main

This commit is contained in:
2024-07-14 12:48:33 +02:00
parent 28f6a0d7a7
commit 2f2ed00ca1
11 changed files with 104 additions and 81 deletions

View File

@@ -4,27 +4,7 @@
#include <utility>
#include "RocAuc.h"
namespace platform {
std::vector<int> tensorToVector(const torch::Tensor& tensor)
{
// Ensure the tensor is of type kInt32
if (tensor.dtype() != torch::kInt32) {
throw std::runtime_error("Tensor must be of type kInt32");
}
// Ensure the tensor is contiguous
torch::Tensor contig_tensor = tensor.contiguous();
// Get the number of elements in the tensor
auto num_elements = contig_tensor.numel();
// Get a pointer to the tensor data
const int32_t* tensor_data = contig_tensor.data_ptr<int32_t>();
// Create a std::vector<int> and copy the data
std::vector<int> result(tensor_data, tensor_data + num_elements);
return result;
}
double RocAuc::compute(const torch::Tensor& y_proba, const torch::Tensor& labels)
{
size_t nClasses = y_proba.size(1);