Add roc-auc-ovr as score to b_main
This commit is contained in:
@@ -3,7 +3,6 @@
|
||||
#include "common/Paths.h"
|
||||
#include "Models.h"
|
||||
#include "Scores.h"
|
||||
#include "RocAuc.h"
|
||||
#include "Experiment.h"
|
||||
namespace platform {
|
||||
using json = nlohmann::ordered_json;
|
||||
@@ -86,7 +85,14 @@ namespace platform {
|
||||
return Colors::RESET();
|
||||
}
|
||||
}
|
||||
|
||||
/// Maps the score name held by `result` to the matching score_t value.
/// Throws std::runtime_error when the name is neither "accuracy" nor "roc-auc-ovr".
score_t Experiment::parse_score() const
{
    const std::string score_name = result.getScoreName();
    if (score_name == "accuracy") {
        return score_t::ACCURACY;
    }
    if (score_name == "roc-auc-ovr") {
        return score_t::ROC_AUC_OVR;
    }
    throw std::runtime_error("Unknown score: " + score_name);
}
|
||||
void showProgress(int fold, const std::string& color, const std::string& phase)
|
||||
{
|
||||
std::string prefix = phase == "-" ? "" : "\b\b\b\b";
|
||||
@@ -159,10 +165,8 @@ namespace platform {
|
||||
// Initialize results std::vectors
|
||||
//
|
||||
int nResults = nfolds * static_cast<int>(randomSeeds.size());
|
||||
auto accuracy_test = torch::zeros({ nResults }, torch::kFloat64);
|
||||
auto accuracy_train = torch::zeros({ nResults }, torch::kFloat64);
|
||||
auto auc_test = torch::zeros({ nResults }, torch::kFloat64);
|
||||
auto auc_train = torch::zeros({ nResults }, torch::kFloat64);
|
||||
auto score_test = torch::zeros({ nResults }, torch::kFloat64);
|
||||
auto score_train = torch::zeros({ nResults }, torch::kFloat64);
|
||||
auto train_time = torch::zeros({ nResults }, torch::kFloat64);
|
||||
auto test_time = torch::zeros({ nResults }, torch::kFloat64);
|
||||
auto nodes = torch::zeros({ nResults }, torch::kFloat64);
|
||||
@@ -178,6 +182,7 @@ namespace platform {
|
||||
//
|
||||
// Loop over random seeds
|
||||
//
|
||||
auto score = parse_score();
|
||||
for (auto seed : randomSeeds) {
|
||||
if (!quiet) {
|
||||
string prefix = " ";
|
||||
@@ -227,17 +232,14 @@ namespace platform {
|
||||
edges[item] = clf->getNumberOfEdges();
|
||||
num_states[item] = clf->getNumberOfStates();
|
||||
train_time[item] = train_timer.getDuration();
|
||||
double accuracy_train_value = 0.0;
|
||||
double score_train_value = 0.0;
|
||||
//
|
||||
// Score train
|
||||
//
|
||||
double auc_train_value = 0;
|
||||
if (!no_train_score) {
|
||||
auto roc_auc = RocAuc();
|
||||
auto y_proba_train = clf->predict_proba(X_train);
|
||||
Scores scores(y_train, y_proba_train, num_classes, labels);
|
||||
accuracy_train_value = scores.accuracy();
|
||||
auc_train_value = roc_auc.compute(y_proba_train, y_train);
|
||||
score_train_value = score == score_t::ACCURACY ? scores.accuracy() : scores.auc();
|
||||
confusion_matrices_train.push_back(scores.get_confusion_matrix_json(true));
|
||||
}
|
||||
//
|
||||
@@ -249,24 +251,18 @@ namespace platform {
|
||||
// auto y_predict = clf->predict(X_test);
|
||||
auto y_proba_test = clf->predict_proba(X_test);
|
||||
Scores scores(y_test, y_proba_test, num_classes, labels);
|
||||
auto accuracy_test_value = scores.accuracy();
|
||||
auto roc_auc = RocAuc();
|
||||
double auc_test_value = roc_auc.compute(y_proba_test, y_test);
|
||||
auto score_test_value = score == score_t::ACCURACY ? scores.accuracy() : scores.auc();
|
||||
test_time[item] = test_timer.getDuration();
|
||||
auc_train[item] = auc_train_value;
|
||||
auc_test[item] = auc_test_value;
|
||||
accuracy_train[item] = accuracy_train_value;
|
||||
accuracy_test[item] = accuracy_test_value;
|
||||
score_train[item] = score_train_value;
|
||||
score_test[item] = score_test_value;
|
||||
confusion_matrices.push_back(scores.get_confusion_matrix_json(true));
|
||||
if (!quiet)
|
||||
std::cout << "\b\b\b, " << flush;
|
||||
//
|
||||
// Store results and times in std::vector
|
||||
//
|
||||
partial_result.addAucTrain(auc_train_value);
|
||||
partial_result.addAucTest(auc_test_value);
|
||||
partial_result.addScoreTrain(accuracy_train_value);
|
||||
partial_result.addScoreTest(accuracy_test_value);
|
||||
partial_result.addScoreTrain(score_train_value);
|
||||
partial_result.addScoreTest(score_test_value);
|
||||
partial_result.addTimeTrain(train_time[item].item<double>());
|
||||
partial_result.addTimeTest(test_time[item].item<double>());
|
||||
item++;
|
||||
@@ -286,10 +282,8 @@ namespace platform {
|
||||
// Store result totals in Result
|
||||
//
|
||||
partial_result.setGraph(graphs);
|
||||
partial_result.setScoreTest(torch::mean(accuracy_test).item<double>()).setScoreTrain(torch::mean(accuracy_train).item<double>());
|
||||
partial_result.setScoreTestStd(torch::std(accuracy_test).item<double>()).setScoreTrainStd(torch::std(accuracy_train).item<double>());
|
||||
partial_result.setAucTest(torch::mean(auc_test).item<double>()).setAucTrain(torch::mean(auc_train).item<double>());
|
||||
partial_result.setAucTestStd(torch::std(auc_test).item<double>()).setAucTrainStd(torch::std(auc_train).item<double>());
|
||||
partial_result.setScoreTest(torch::mean(score_test).item<double>()).setScoreTrain(torch::mean(score_train).item<double>());
|
||||
partial_result.setScoreTestStd(torch::std(score_test).item<double>()).setScoreTrainStd(torch::std(score_train).item<double>());
|
||||
partial_result.setTrainTime(torch::mean(train_time).item<double>()).setTestTime(torch::mean(test_time).item<double>());
|
||||
partial_result.setTestTimeStd(torch::std(test_time).item<double>()).setTrainTimeStd(torch::std(train_time).item<double>());
|
||||
partial_result.setNodes(torch::mean(nodes).item<double>()).setLeaves(torch::mean(edges).item<double>()).setDepth(torch::mean(num_states).item<double>());
|
||||
|
@@ -11,7 +11,7 @@
|
||||
|
||||
namespace platform {
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
// Score metric used by an experiment; Experiment::parse_score() maps the names "accuracy" and "roc-auc-ovr" to ACCURACY and ROC_AUC_OVR.
enum class score_t { NONE, ACCURACY, ROC_AUC_OVR };
|
||||
class Experiment {
|
||||
public:
|
||||
Experiment() = default;
|
||||
@@ -55,6 +55,7 @@ namespace platform {
|
||||
void saveGraph();
|
||||
void report(bool classification_report = false);
|
||||
private:
|
||||
score_t parse_score() const;
|
||||
Result result;
|
||||
bool discretized{ false }, stratified{ false };
|
||||
std::vector<PartialResult> results;
|
||||
|
@@ -44,10 +44,6 @@ namespace platform {
|
||||
/// Writes score_std to data["score_train_std"] and returns *this so calls can be chained.
PartialResult& setScoreTrainStd(double score_std)
{
    data["score_train_std"] = score_std;
    return *this;
}
|
||||
/// Writes score to data["score"] and returns *this so calls can be chained.
PartialResult& setScoreTest(double score)
{
    data["score"] = score;
    return *this;
}
|
||||
/// Writes score_std to data["score_std"] and returns *this so calls can be chained.
PartialResult& setScoreTestStd(double score_std)
{
    data["score_std"] = score_std;
    return *this;
}
|
||||
/// Writes score to data["auc_train"] and returns *this so calls can be chained.
PartialResult& setAucTrain(double score)
{
    data["auc_train"] = score;
    return *this;
}
|
||||
/// Writes score_std to data["auc_train_std"] and returns *this so calls can be chained.
PartialResult& setAucTrainStd(double score_std)
{
    data["auc_train_std"] = score_std;
    return *this;
}
|
||||
/// Writes score to data["auc"] and returns *this so calls can be chained.
PartialResult& setAucTest(double score)
{
    data["auc"] = score;
    return *this;
}
|
||||
/// Writes score_std to data["auc_std"] and returns *this so calls can be chained.
PartialResult& setAucTestStd(double score_std)
{
    data["auc_std"] = score_std;
    return *this;
}
|
||||
PartialResult& setTrainTime(double train_time)
|
||||
{
|
||||
data["train_time"] = train_time;
|
||||
@@ -75,8 +71,6 @@ namespace platform {
|
||||
/// Writes nodes to data["nodes"] and returns *this so calls can be chained.
PartialResult& setNodes(float nodes)
{
    data["nodes"] = nodes;
    return *this;
}
|
||||
/// Writes leaves to data["leaves"] and returns *this so calls can be chained.
PartialResult& setLeaves(float leaves)
{
    data["leaves"] = leaves;
    return *this;
}
|
||||
/// Writes depth to data["depth"] and returns *this so calls can be chained.
PartialResult& setDepth(float depth)
{
    data["depth"] = depth;
    return *this;
}
|
||||
/// Appends score to data["aucs_train"] and returns *this so calls can be chained.
PartialResult& addAucTrain(double score)
{
    data["aucs_train"].push_back(score);
    return *this;
}
|
||||
/// Appends score to data["aucs_test"] and returns *this so calls can be chained.
PartialResult& addAucTest(double score)
{
    data["aucs_test"].push_back(score);
    return *this;
}
|
||||
/// Appends score to data["scores_train"] and returns *this so calls can be chained.
PartialResult& addScoreTrain(double score)
{
    data["scores_train"].push_back(score);
    return *this;
}
|
||||
/// Appends score to data["scores_test"] and returns *this so calls can be chained.
PartialResult& addScoreTest(double score)
{
    data["scores_test"].push_back(score);
    return *this;
}
|
||||
/// Appends time to data["times_train"] and returns *this so calls can be chained.
PartialResult& addTimeTrain(double time)
{
    data["times_train"].push_back(time);
    return *this;
}
|
||||
|
@@ -4,27 +4,7 @@
|
||||
#include <utility>
|
||||
#include "RocAuc.h"
|
||||
namespace platform {
|
||||
/// Copies every element of a kInt32 tensor into a std::vector<int>.
/// Throws std::runtime_error when the tensor's dtype is not kInt32.
std::vector<int> tensorToVector(const torch::Tensor& tensor)
{
    // Only 32-bit integer tensors are accepted
    if (tensor.dtype() != torch::kInt32) {
        throw std::runtime_error("Tensor must be of type kInt32");
    }
    // Work on a contiguous tensor so its elements can be read through one pointer range
    const torch::Tensor flat = tensor.contiguous();
    const int32_t* first = flat.data_ptr<int32_t>();
    const int32_t* last = first + flat.numel();
    return std::vector<int>(first, last);
}
|
||||
|
||||
double RocAuc::compute(const torch::Tensor& y_proba, const torch::Tensor& labels)
|
||||
{
|
||||
size_t nClasses = y_proba.size(1);
|
||||
|
@@ -1,8 +1,9 @@
|
||||
#include <sstream>
|
||||
#include "Scores.h"
|
||||
#include "common/Utils.h" // tensorToVector
|
||||
#include "common/Colors.h"
|
||||
namespace platform {
|
||||
Scores::Scores(torch::Tensor& y_test, torch::Tensor& y_proba, int num_classes, std::vector<std::string> labels) : num_classes(num_classes), labels(labels)
|
||||
Scores::Scores(torch::Tensor& y_test, torch::Tensor& y_proba, int num_classes, std::vector<std::string> labels) : num_classes(num_classes), labels(labels), y_test(y_test), y_proba(y_proba)
|
||||
{
|
||||
if (labels.size() == 0) {
|
||||
init_default_labels();
|
||||
@@ -41,6 +42,44 @@ namespace platform {
|
||||
}
|
||||
compute_accuracy_value();
|
||||
}
|
||||
float Scores::auc()
|
||||
{
|
||||
size_t nSamples = y_test.numel();
|
||||
if (nSamples == 0) return 0;
|
||||
// In binary classification problem there's no need to calculate the average of the AUCs
|
||||
auto nClasses = num_classes;
|
||||
if (num_classes == 2)
|
||||
nClasses = 1;
|
||||
auto y_testv = tensorToVector<int>(y_test);
|
||||
std::vector<double> aucScores(nClasses, 0.0);
|
||||
std::vector<std::pair<double, int>> scoresAndLabels;
|
||||
for (size_t classIdx = 0; classIdx < nClasses; ++classIdx) {
|
||||
scoresAndLabels.clear();
|
||||
for (size_t i = 0; i < nSamples; ++i) {
|
||||
scoresAndLabels.emplace_back(y_proba[i][classIdx].item<float>(), y_testv[i] == classIdx ? 1 : 0);
|
||||
}
|
||||
std::sort(scoresAndLabels.begin(), scoresAndLabels.end(), std::greater<>());
|
||||
std::vector<double> tpr, fpr;
|
||||
double tp = 0, fp = 0;
|
||||
double totalPos = std::count(y_testv.begin(), y_testv.end(), classIdx);
|
||||
double totalNeg = nSamples - totalPos;
|
||||
for (const auto& [score, label] : scoresAndLabels) {
|
||||
if (label == 1) {
|
||||
tp += 1;
|
||||
} else {
|
||||
fp += 1;
|
||||
}
|
||||
tpr.push_back(tp / totalPos);
|
||||
fpr.push_back(fp / totalNeg);
|
||||
}
|
||||
double auc = 0.0;
|
||||
for (size_t i = 1; i < tpr.size(); ++i) {
|
||||
auc += 0.5 * (fpr[i] - fpr[i - 1]) * (tpr[i] + tpr[i - 1]);
|
||||
}
|
||||
aucScores[classIdx] = auc;
|
||||
}
|
||||
return std::accumulate(aucScores.begin(), aucScores.end(), 0.0) / nClasses;
|
||||
}
|
||||
Scores Scores::create_aggregate(const json& data, const std::string key)
|
||||
{
|
||||
auto scores = Scores(data[key][0]);
|
||||
|
@@ -9,10 +9,11 @@ namespace platform {
|
||||
using json = nlohmann::ordered_json;
|
||||
class Scores {
|
||||
public:
|
||||
Scores(torch::Tensor& y_test, torch::Tensor& y_pred, int num_classes, std::vector<std::string> labels = {});
|
||||
Scores(torch::Tensor& y_test, torch::Tensor& y_proba, int num_classes, std::vector<std::string> labels = {});
|
||||
explicit Scores(const json& confusion_matrix_);
|
||||
static Scores create_aggregate(const json& data, const std::string key);
|
||||
float accuracy();
|
||||
float auc();
|
||||
float f1_score(int num_class);
|
||||
float f1_weighted();
|
||||
float f1_macro();
|
||||
@@ -34,6 +35,9 @@ namespace platform {
|
||||
int total;
|
||||
std::vector<std::string> labels;
|
||||
torch::Tensor confusion_matrix; // Rows are actual, columns are predicted
|
||||
torch::Tensor null_t; // Convenient null tensor needed when confusion_matrix constructor is used
|
||||
torch::Tensor& y_test = null_t; // for ROC AUC
|
||||
torch::Tensor& y_proba = null_t; // for ROC AUC
|
||||
int label_len = 16;
|
||||
int dlen = 9;
|
||||
int ndec = 7;
|
||||
|
Reference in New Issue
Block a user