Add roc-auc-ovr as score to b_main

This commit is contained in:
2024-07-14 12:48:33 +02:00
parent 28f6a0d7a7
commit 2f2ed00ca1
11 changed files with 104 additions and 81 deletions

View File

@@ -3,7 +3,6 @@
#include "common/Paths.h"
#include "Models.h"
#include "Scores.h"
#include "RocAuc.h"
#include "Experiment.h"
namespace platform {
using json = nlohmann::ordered_json;
@@ -86,7 +85,14 @@ namespace platform {
return Colors::RESET();
}
}
score_t Experiment::parse_score() const
{
if (result.getScoreName() == "accuracy")
return score_t::ACCURACY;
if (result.getScoreName() == "roc-auc-ovr")
return score_t::ROC_AUC_OVR;
throw std::runtime_error("Unknown score: " + result.getScoreName());
}
void showProgress(int fold, const std::string& color, const std::string& phase)
{
std::string prefix = phase == "-" ? "" : "\b\b\b\b";
@@ -159,10 +165,8 @@ namespace platform {
// Initialize results std::vectors
//
int nResults = nfolds * static_cast<int>(randomSeeds.size());
auto accuracy_test = torch::zeros({ nResults }, torch::kFloat64);
auto accuracy_train = torch::zeros({ nResults }, torch::kFloat64);
auto auc_test = torch::zeros({ nResults }, torch::kFloat64);
auto auc_train = torch::zeros({ nResults }, torch::kFloat64);
auto score_test = torch::zeros({ nResults }, torch::kFloat64);
auto score_train = torch::zeros({ nResults }, torch::kFloat64);
auto train_time = torch::zeros({ nResults }, torch::kFloat64);
auto test_time = torch::zeros({ nResults }, torch::kFloat64);
auto nodes = torch::zeros({ nResults }, torch::kFloat64);
@@ -178,6 +182,7 @@ namespace platform {
//
// Loop over random seeds
//
auto score = parse_score();
for (auto seed : randomSeeds) {
if (!quiet) {
string prefix = " ";
@@ -227,17 +232,14 @@ namespace platform {
edges[item] = clf->getNumberOfEdges();
num_states[item] = clf->getNumberOfStates();
train_time[item] = train_timer.getDuration();
double accuracy_train_value = 0.0;
double score_train_value = 0.0;
//
// Score train
//
double auc_train_value = 0;
if (!no_train_score) {
auto roc_auc = RocAuc();
auto y_proba_train = clf->predict_proba(X_train);
Scores scores(y_train, y_proba_train, num_classes, labels);
accuracy_train_value = scores.accuracy();
auc_train_value = roc_auc.compute(y_proba_train, y_train);
score_train_value = score == score_t::ACCURACY ? scores.accuracy() : scores.auc();
confusion_matrices_train.push_back(scores.get_confusion_matrix_json(true));
}
//
@@ -249,24 +251,18 @@ namespace platform {
// auto y_predict = clf->predict(X_test);
auto y_proba_test = clf->predict_proba(X_test);
Scores scores(y_test, y_proba_test, num_classes, labels);
auto accuracy_test_value = scores.accuracy();
auto roc_auc = RocAuc();
double auc_test_value = roc_auc.compute(y_proba_test, y_test);
auto score_test_value = score == score_t::ACCURACY ? scores.accuracy() : scores.auc();
test_time[item] = test_timer.getDuration();
auc_train[item] = auc_train_value;
auc_test[item] = auc_test_value;
accuracy_train[item] = accuracy_train_value;
accuracy_test[item] = accuracy_test_value;
score_train[item] = score_train_value;
score_test[item] = score_test_value;
confusion_matrices.push_back(scores.get_confusion_matrix_json(true));
if (!quiet)
std::cout << "\b\b\b, " << flush;
//
// Store results and times in std::vector
//
partial_result.addAucTrain(auc_train_value);
partial_result.addAucTest(auc_test_value);
partial_result.addScoreTrain(accuracy_train_value);
partial_result.addScoreTest(accuracy_test_value);
partial_result.addScoreTrain(score_train_value);
partial_result.addScoreTest(score_test_value);
partial_result.addTimeTrain(train_time[item].item<double>());
partial_result.addTimeTest(test_time[item].item<double>());
item++;
@@ -286,10 +282,8 @@ namespace platform {
// Store result totals in Result
//
partial_result.setGraph(graphs);
partial_result.setScoreTest(torch::mean(accuracy_test).item<double>()).setScoreTrain(torch::mean(accuracy_train).item<double>());
partial_result.setScoreTestStd(torch::std(accuracy_test).item<double>()).setScoreTrainStd(torch::std(accuracy_train).item<double>());
partial_result.setAucTest(torch::mean(auc_test).item<double>()).setAucTrain(torch::mean(auc_train).item<double>());
partial_result.setAucTestStd(torch::std(auc_test).item<double>()).setAucTrainStd(torch::std(auc_train).item<double>());
partial_result.setScoreTest(torch::mean(score_test).item<double>()).setScoreTrain(torch::mean(score_train).item<double>());
partial_result.setScoreTestStd(torch::std(score_test).item<double>()).setScoreTrainStd(torch::std(score_train).item<double>());
partial_result.setTrainTime(torch::mean(train_time).item<double>()).setTestTime(torch::mean(test_time).item<double>());
partial_result.setTestTimeStd(torch::std(test_time).item<double>()).setTrainTimeStd(torch::std(train_time).item<double>());
partial_result.setNodes(torch::mean(nodes).item<double>()).setLeaves(torch::mean(edges).item<double>()).setDepth(torch::mean(num_states).item<double>());