Add BoostA2DE model and fix some report errors

This commit is contained in:
2024-05-17 01:25:27 +02:00
parent 30a6d5e60d
commit 696c0564a7
9 changed files with 25 additions and 22 deletions

4
.vscode/launch.json vendored
View File

@@ -99,7 +99,9 @@
"request": "launch",
"program": "${workspaceFolder}/build_debug/src/b_list",
"args": [
"datasets"
"results",
"-d",
"mfeat-morphological"
],
//"cwd": "/Users/rmontanana/Code/discretizbench",
"cwd": "${workspaceFolder}/../discretizbench",

View File

@@ -55,6 +55,10 @@ namespace platform {
}
}
}
if (bests.empty()) {
std::cerr << Colors::MAGENTA() << "No results found for model " << model << " and score " << score << Colors::RESET() << std::endl;
exit(1);
}
std::string bestFileName = path + bestResultFile();
std::ofstream file(bestFileName);
file << bests;

View File

@@ -37,7 +37,8 @@ void list_results(argparse::ArgumentParser& program)
auto model = program.get<string>("model");
auto excel = program.get<bool>("excel");
auto report = platform::ResultsDatasetsConsole();
report.report(dataset, score, model);
if (!report.report(dataset, score, model))
return;
std::cout << report.getOutput();
if (excel) {
auto data = report.getData();

View File

@@ -6,6 +6,7 @@
#include <bayesnet/ensembles/A2DE.h>
#include <bayesnet/ensembles/AODELd.h>
#include <bayesnet/ensembles/BoostAODE.h>
#include <bayesnet/ensembles/BoostA2DE.h>
#include <bayesnet/classifiers/TAN.h>
#include <bayesnet/classifiers/KDB.h>
#include <bayesnet/classifiers/SPODE.h>
@@ -13,6 +14,7 @@
#include <bayesnet/classifiers/TANLd.h>
#include <bayesnet/classifiers/KDBLd.h>
#include <bayesnet/classifiers/SPODELd.h>
#include <bayesnet/classifiers/SPODELd.h>
#include <pyclassifiers/STree.h>
#include <pyclassifiers/ODTE.h>
#include <pyclassifiers/SVC.h>

View File

@@ -22,6 +22,8 @@ static platform::Registrar registrarALD("AODELd",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::AODELd();});
static platform::Registrar registrarBA("BoostAODE",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::BoostAODE();});
static platform::Registrar registrarBA2("BoostA2DE",
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::BoostA2DE();});
static platform::Registrar registrarSt("STree",
[](void) -> bayesnet::BaseClassifier* { return new pywrap::STree();});
static platform::Registrar registrarOdte("Odte",

View File

@@ -4,15 +4,15 @@
#include "results/ResultsDataset.h"
#include "ResultsDatasetConsole.h"
namespace platform {
void ResultsDatasetsConsole::report(const std::string& dataset, const std::string& score, const std::string& model)
bool ResultsDatasetsConsole::report(const std::string& dataset, const std::string& score, const std::string& model)
{
auto results = platform::ResultsDataset(dataset, model, score);
results.load();
results.sortModel();
if (results.empty()) {
std::cerr << Colors::RED() << "No results found for dataset " << dataset << " and model " << model << Colors::RESET() << std::endl;
return;
return false;
}
results.sortModel();
int maxModel = results.maxModelSize();
int maxHyper = results.maxHyperSize();
double maxResult = results.maxResultScore();
@@ -76,6 +76,7 @@ namespace platform {
oss << item["hyperparameters"].get<std::string>() << std::endl;
body.push_back(oss.str());
}
return true;
}
}

View File

@@ -12,7 +12,7 @@ namespace platform {
public:
ResultsDatasetsConsole() = default;
~ResultsDatasetsConsole() = default;
void report(const std::string& dataset, const std::string& score, const std::string& model);
bool report(const std::string& dataset, const std::string& score, const std::string& model);
};
}

View File

@@ -31,6 +31,8 @@ namespace platform {
}
}
}
if (files.empty())
return;
maxModel = std::max(size_t(5), (*max_element(files.begin(), files.end(), [](const Result& a, const Result& b) { return a.getModel().size() < b.getModel().size(); })).getModel().size());
maxFile = std::max(size_t(4), (*max_element(files.begin(), files.end(), [](const Result& a, const Result& b) { return a.getFilename().size() < b.getFilename().size(); })).getFilename().size());
}

View File

@@ -7,6 +7,7 @@
#include "common/DotEnv.h"
#include "common/Datasets.h"
#include "common/Paths.h"
#include "common/Colors.h"
#include "main/Scores.h"
#include "config.h"
@@ -127,7 +128,7 @@ TEST_CASE("Confusion Matrix JSON", "[Scores]")
REQUIRE(res_json_str["Car"][1] == 2);
REQUIRE(res_json_str["Car"][2] == 3);
}
TEST_CASE("Classification Report", "[Scores]")
TEST_CASE("Classification Report", "[Scores]") -
{
std::vector<int> y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 };
std::vector<int> y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 };
@@ -135,19 +136,7 @@ TEST_CASE("Classification Report", "[Scores]")
auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
std::vector<std::string> labels = { "Aeroplane", "Boat", "Car" };
platform::Scores scores(y_test_tensor, y_pred_tensor, 3, labels);
std::string expected = R"(Classification Report
=====================
precision recall f1-score support
========= ========= ========= =========
Aeroplane 0.6666667 0.6666667 0.6666667 3
Boat 0.2500000 1.0000000 0.4000000 1
Car 1.0000000 0.5000000 0.6666667 6
accuracy 0.6000000 10
macro avg 0.6388889 0.7222223 0.5777778 10
weighted avg 0.8250000 0.6000000 0.6400000 10
)";
REQUIRE(scores.classification_report() == expected);
auto report = scores.classification_report(Colors::BLUE(), "train");
auto json_matrix = scores.get_confusion_matrix_json(true);
platform::Scores scores2(json_matrix);
REQUIRE(scores.classification_report() == scores2.classification_report());