From 028522f18041d01af59f8f89963c07015a0ffc8f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Fri, 12 Jul 2024 17:41:23 +0200 Subject: [PATCH] Add AUC to reportConsole --- src/reports/ReportConsole.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/reports/ReportConsole.cpp b/src/reports/ReportConsole.cpp index ccc5a87..9d4b410 100644 --- a/src/reports/ReportConsole.cpp +++ b/src/reports/ReportConsole.cpp @@ -65,9 +65,9 @@ namespace platform { maxHyper = std::max(maxHyper, (int)r["hyperparameters"].dump().size()); maxDataset = std::max(maxDataset, (int)r["dataset"].get().size()); } - std::vector header_labels = { " #", "Dataset", "Sampl.", "Feat.", "Cls", nodes_label, leaves_label, depth_label, "Score", "Time", "Hyperparameters" }; + std::vector header_labels = { " #", "Dataset", "Sampl.", "Feat.", "Cls", nodes_label, leaves_label, depth_label, "Score", "ROC-AUC ovr", "Time", "Hyperparameters" }; sheader << Colors::GREEN(); - std::vector header_lengths = { 3, maxDataset, 6, 5, 3, 9, 9, 9, 15, 20, maxHyper }; + std::vector header_lengths = { 3, maxDataset, 6, 5, 3, 9, 9, 9, 15, 15, 20, maxHyper }; for (int i = 0; i < header_labels.size(); i++) { sheader << std::setw(header_lengths[i]) << std::left << header_labels[i] << " "; } @@ -99,6 +99,7 @@ namespace platform { line << std::setw(8) << std::right << std::setprecision(6) << std::fixed << r["score"].get() << "±" << std::setw(6) << std::setprecision(4) << std::fixed << r["score_std"].get(); const std::string status = compareResult(r["dataset"].get(), r["score"].get()); line << status; + line << std::setw(8) << std::right << std::setprecision(6) << std::fixed << r["auc"].get() << "±" << std::setw(6) << std::setprecision(4) << std::fixed << r["auc_std"].get() << " "; line << std::setw(12) << std::right << std::setprecision(6) << std::fixed << r["time"].get() << "±" << std::setw(7) << std::setprecision(4) << std::fixed << r["time_std"].get() << " "; line << r["hyperparameters"].dump(); line << std::endl; @@ -128,6 +129,10 @@ namespace platform { vbody.push_back(line.str()); sbody << line.str(); line.str(""); line << headerLine(fVector("Test scores: ", lastResult["scores_test"], 14, 12)); vbody.push_back(line.str()); sbody << line.str(); + line.str(""); line << headerLine(fVector("Train auc : ", lastResult["aucs_train"], 14, 12)); + vbody.push_back(line.str()); sbody << line.str(); + line.str(""); line << headerLine(fVector("Test auc : ", lastResult["aucs_test"], 14, 12)); + vbody.push_back(line.str()); sbody << line.str(); line.str(""); line << headerLine(fVector("Train times: ", lastResult["times_train"], 10, 3)); vbody.push_back(line.str()); sbody << line.str(); line.str(""); line << headerLine(fVector("Test times: ", lastResult["times_test"], 10, 3));