Add roc-auc-ovr as score to b_main
This commit is contained in:
@@ -3,7 +3,6 @@
|
||||
#include "common/Paths.h"
|
||||
#include "Models.h"
|
||||
#include "Scores.h"
|
||||
#include "RocAuc.h"
|
||||
#include "Experiment.h"
|
||||
namespace platform {
|
||||
using json = nlohmann::ordered_json;
|
||||
@@ -86,7 +85,14 @@ namespace platform {
|
||||
return Colors::RESET();
|
||||
}
|
||||
}
|
||||
|
||||
score_t Experiment::parse_score() const
|
||||
{
|
||||
if (result.getScoreName() == "accuracy")
|
||||
return score_t::ACCURACY;
|
||||
if (result.getScoreName() == "roc-auc-ovr")
|
||||
return score_t::ROC_AUC_OVR;
|
||||
throw std::runtime_error("Unknown score: " + result.getScoreName());
|
||||
}
|
||||
void showProgress(int fold, const std::string& color, const std::string& phase)
|
||||
{
|
||||
std::string prefix = phase == "-" ? "" : "\b\b\b\b";
|
||||
@@ -159,10 +165,8 @@ namespace platform {
|
||||
// Initialize results std::vectors
|
||||
//
|
||||
int nResults = nfolds * static_cast<int>(randomSeeds.size());
|
||||
auto accuracy_test = torch::zeros({ nResults }, torch::kFloat64);
|
||||
auto accuracy_train = torch::zeros({ nResults }, torch::kFloat64);
|
||||
auto auc_test = torch::zeros({ nResults }, torch::kFloat64);
|
||||
auto auc_train = torch::zeros({ nResults }, torch::kFloat64);
|
||||
auto score_test = torch::zeros({ nResults }, torch::kFloat64);
|
||||
auto score_train = torch::zeros({ nResults }, torch::kFloat64);
|
||||
auto train_time = torch::zeros({ nResults }, torch::kFloat64);
|
||||
auto test_time = torch::zeros({ nResults }, torch::kFloat64);
|
||||
auto nodes = torch::zeros({ nResults }, torch::kFloat64);
|
||||
@@ -178,6 +182,7 @@ namespace platform {
|
||||
//
|
||||
// Loop over random seeds
|
||||
//
|
||||
auto score = parse_score();
|
||||
for (auto seed : randomSeeds) {
|
||||
if (!quiet) {
|
||||
string prefix = " ";
|
||||
@@ -227,17 +232,14 @@ namespace platform {
|
||||
edges[item] = clf->getNumberOfEdges();
|
||||
num_states[item] = clf->getNumberOfStates();
|
||||
train_time[item] = train_timer.getDuration();
|
||||
double accuracy_train_value = 0.0;
|
||||
double score_train_value = 0.0;
|
||||
//
|
||||
// Score train
|
||||
//
|
||||
double auc_train_value = 0;
|
||||
if (!no_train_score) {
|
||||
auto roc_auc = RocAuc();
|
||||
auto y_proba_train = clf->predict_proba(X_train);
|
||||
Scores scores(y_train, y_proba_train, num_classes, labels);
|
||||
accuracy_train_value = scores.accuracy();
|
||||
auc_train_value = roc_auc.compute(y_proba_train, y_train);
|
||||
score_train_value = score == score_t::ACCURACY ? scores.accuracy() : scores.auc();
|
||||
confusion_matrices_train.push_back(scores.get_confusion_matrix_json(true));
|
||||
}
|
||||
//
|
||||
@@ -249,24 +251,18 @@ namespace platform {
|
||||
// auto y_predict = clf->predict(X_test);
|
||||
auto y_proba_test = clf->predict_proba(X_test);
|
||||
Scores scores(y_test, y_proba_test, num_classes, labels);
|
||||
auto accuracy_test_value = scores.accuracy();
|
||||
auto roc_auc = RocAuc();
|
||||
double auc_test_value = roc_auc.compute(y_proba_test, y_test);
|
||||
auto score_test_value = score == score_t::ACCURACY ? scores.accuracy() : scores.auc();
|
||||
test_time[item] = test_timer.getDuration();
|
||||
auc_train[item] = auc_train_value;
|
||||
auc_test[item] = auc_test_value;
|
||||
accuracy_train[item] = accuracy_train_value;
|
||||
accuracy_test[item] = accuracy_test_value;
|
||||
score_train[item] = score_train_value;
|
||||
score_test[item] = score_test_value;
|
||||
confusion_matrices.push_back(scores.get_confusion_matrix_json(true));
|
||||
if (!quiet)
|
||||
std::cout << "\b\b\b, " << flush;
|
||||
//
|
||||
// Store results and times in std::vector
|
||||
//
|
||||
partial_result.addAucTrain(auc_train_value);
|
||||
partial_result.addAucTest(auc_test_value);
|
||||
partial_result.addScoreTrain(accuracy_train_value);
|
||||
partial_result.addScoreTest(accuracy_test_value);
|
||||
partial_result.addScoreTrain(score_train_value);
|
||||
partial_result.addScoreTest(score_test_value);
|
||||
partial_result.addTimeTrain(train_time[item].item<double>());
|
||||
partial_result.addTimeTest(test_time[item].item<double>());
|
||||
item++;
|
||||
@@ -286,10 +282,8 @@ namespace platform {
|
||||
// Store result totals in Result
|
||||
//
|
||||
partial_result.setGraph(graphs);
|
||||
partial_result.setScoreTest(torch::mean(accuracy_test).item<double>()).setScoreTrain(torch::mean(accuracy_train).item<double>());
|
||||
partial_result.setScoreTestStd(torch::std(accuracy_test).item<double>()).setScoreTrainStd(torch::std(accuracy_train).item<double>());
|
||||
partial_result.setAucTest(torch::mean(auc_test).item<double>()).setAucTrain(torch::mean(auc_train).item<double>());
|
||||
partial_result.setAucTestStd(torch::std(auc_test).item<double>()).setAucTrainStd(torch::std(auc_train).item<double>());
|
||||
partial_result.setScoreTest(torch::mean(score_test).item<double>()).setScoreTrain(torch::mean(score_train).item<double>());
|
||||
partial_result.setScoreTestStd(torch::std(score_test).item<double>()).setScoreTrainStd(torch::std(score_train).item<double>());
|
||||
partial_result.setTrainTime(torch::mean(train_time).item<double>()).setTestTime(torch::mean(test_time).item<double>());
|
||||
partial_result.setTestTimeStd(torch::std(test_time).item<double>()).setTrainTimeStd(torch::std(train_time).item<double>());
|
||||
partial_result.setNodes(torch::mean(nodes).item<double>()).setLeaves(torch::mean(edges).item<double>()).setDepth(torch::mean(num_states).item<double>());
|
||||
|
Reference in New Issue
Block a user