Add BoostA2DE model and fix some report errors

2024-05-17 01:25:27 +02:00
parent 30a6d5e60d
commit 696c0564a7
9 changed files with 25 additions and 22 deletions

.vscode/launch.json (vendored)

@@ -99,7 +99,9 @@
"request": "launch", "request": "launch",
"program": "${workspaceFolder}/build_debug/src/b_list", "program": "${workspaceFolder}/build_debug/src/b_list",
"args": [ "args": [
"datasets" "results",
"-d",
"mfeat-morphological"
], ],
//"cwd": "/Users/rmontanana/Code/discretizbench", //"cwd": "/Users/rmontanana/Code/discretizbench",
"cwd": "${workspaceFolder}/../discretizbench", "cwd": "${workspaceFolder}/../discretizbench",


@@ -55,6 +55,10 @@ namespace platform {
                 }
             }
         }
+        if (bests.empty()) {
+            std::cerr << Colors::MAGENTA() << "No results found for model " << model << " and score " << score << Colors::RESET() << std::endl;
+            exit(1);
+        }
         std::string bestFileName = path + bestResultFile();
         std::ofstream file(bestFileName);
         file << bests;


@@ -37,7 +37,8 @@ void list_results(argparse::ArgumentParser& program)
     auto model = program.get<string>("model");
     auto excel = program.get<bool>("excel");
     auto report = platform::ResultsDatasetsConsole();
-    report.report(dataset, score, model);
+    if (!report.report(dataset, score, model))
+        return;
     std::cout << report.getOutput();
     if (excel) {
         auto data = report.getData();
@@ -73,7 +74,7 @@ int main(int argc, char** argv)
             }
             throw std::runtime_error("Dataset must be one of " + datasets.toString());
         }
     );
     results_command.add_argument("-m", "--model")
         .help("Model to use: " + platform::Models::instance()->toString() + " or any")
         .default_value("any")
@@ -86,7 +87,7 @@ int main(int argc, char** argv)
             }
             throw std::runtime_error("Model must be one of " + platform::Models::instance()->toString() + " or any");
         }
     );
     results_command.add_argument("--excel").help("Output in Excel format").default_value(false).implicit_value(true);
     results_command.add_argument("-s", "--score").default_value("accuracy").help("Filter results of the score name supplied");


@@ -6,6 +6,7 @@
 #include <bayesnet/ensembles/A2DE.h>
 #include <bayesnet/ensembles/AODELd.h>
 #include <bayesnet/ensembles/BoostAODE.h>
+#include <bayesnet/ensembles/BoostA2DE.h>
 #include <bayesnet/classifiers/TAN.h>
 #include <bayesnet/classifiers/KDB.h>
 #include <bayesnet/classifiers/SPODE.h>
@@ -13,6 +14,7 @@
 #include <bayesnet/classifiers/TANLd.h>
 #include <bayesnet/classifiers/KDBLd.h>
 #include <bayesnet/classifiers/SPODELd.h>
+#include <bayesnet/classifiers/SPODELd.h>
 #include <pyclassifiers/STree.h>
 #include <pyclassifiers/ODTE.h>
 #include <pyclassifiers/SVC.h>


@@ -22,6 +22,8 @@ static platform::Registrar registrarALD("AODELd",
     [](void) -> bayesnet::BaseClassifier* { return new bayesnet::AODELd();});
 static platform::Registrar registrarBA("BoostAODE",
     [](void) -> bayesnet::BaseClassifier* { return new bayesnet::BoostAODE();});
+static platform::Registrar registrarBA2("BoostA2DE",
+    [](void) -> bayesnet::BaseClassifier* { return new bayesnet::BoostA2DE();});
 static platform::Registrar registrarSt("STree",
     [](void) -> bayesnet::BaseClassifier* { return new pywrap::STree();});
 static platform::Registrar registrarOdte("Odte",
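
The Registrar implementation itself is not part of this diff; from its use above it presumably records a name-to-factory mapping at static-initialization time so that "BoostA2DE" can later be built by name. A minimal sketch of that idiom, under that assumption (types and method names here are illustrative, not the platform's real API):

// Hypothetical sketch of a name -> factory registry; the real platform::Registrar,
// Models and bayesnet classes are not shown in this commit and may differ.
#include <functional>
#include <map>
#include <string>

struct BaseClassifier { virtual ~BaseClassifier() = default; };

class Models {
public:
    static Models* instance() { static Models models; return &models; }
    void add(const std::string& name, std::function<BaseClassifier* (void)> factory)
    {
        registry[name] = std::move(factory);   // remember how to build this model by name
    }
    BaseClassifier* create(const std::string& name) { return registry.at(name)(); }
private:
    std::map<std::string, std::function<BaseClassifier* (void)>> registry;
};

class Registrar {
public:
    Registrar(const std::string& name, std::function<BaseClassifier* (void)> factory)
    {
        Models::instance()->add(name, std::move(factory));   // runs during static initialization
    }
};

// Usage mirroring the hunk above (illustrative only):
static Registrar registrarBA2("BoostA2DE", [](void) -> BaseClassifier* { return new BaseClassifier(); });

Under this reading, each static platform::Registrar line in the hunk amounts to one add() call executed before main() runs.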


@@ -4,15 +4,15 @@
#include "results/ResultsDataset.h" #include "results/ResultsDataset.h"
#include "ResultsDatasetConsole.h" #include "ResultsDatasetConsole.h"
namespace platform { namespace platform {
void ResultsDatasetsConsole::report(const std::string& dataset, const std::string& score, const std::string& model) bool ResultsDatasetsConsole::report(const std::string& dataset, const std::string& score, const std::string& model)
{ {
auto results = platform::ResultsDataset(dataset, model, score); auto results = platform::ResultsDataset(dataset, model, score);
results.load(); results.load();
results.sortModel();
if (results.empty()) { if (results.empty()) {
std::cerr << Colors::RED() << "No results found for dataset " << dataset << " and model " << model << Colors::RESET() << std::endl; std::cerr << Colors::RED() << "No results found for dataset " << dataset << " and model " << model << Colors::RESET() << std::endl;
return; return false;
} }
results.sortModel();
int maxModel = results.maxModelSize(); int maxModel = results.maxModelSize();
int maxHyper = results.maxHyperSize(); int maxHyper = results.maxHyperSize();
double maxResult = results.maxResultScore(); double maxResult = results.maxResultScore();
@@ -76,6 +76,7 @@ namespace platform {
oss << item["hyperparameters"].get<std::string>() << std::endl; oss << item["hyperparameters"].get<std::string>() << std::endl;
body.push_back(oss.str()); body.push_back(oss.str());
} }
return true;
} }
} }


@@ -12,7 +12,7 @@ namespace platform {
     public:
         ResultsDatasetsConsole() = default;
         ~ResultsDatasetsConsole() = default;
-        void report(const std::string& dataset, const std::string& score, const std::string& model);
+        bool report(const std::string& dataset, const std::string& score, const std::string& model);
     };
 }


@@ -31,6 +31,8 @@ namespace platform {
                 }
             }
         }
+        if (files.empty())
+            return;
         maxModel = std::max(size_t(5), (*max_element(files.begin(), files.end(), [](const Result& a, const Result& b) { return a.getModel().size() < b.getModel().size(); })).getModel().size());
         maxFile = std::max(size_t(4), (*max_element(files.begin(), files.end(), [](const Result& a, const Result& b) { return a.getFilename().size() < b.getFilename().size(); })).getFilename().size());
     }
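
Why the added guard matters (reasoning only, not part of the commit): std::max_element over an empty range returns the end iterator, and the two lines above dereference the result immediately, so running with no result files would be undefined behaviour. A minimal illustration of the safe pattern, using a hypothetical helper:

// Illustrative helper (hypothetical): compute a column width the way the hunk above does,
// but only dereference max_element's result when the range is known to be non-empty.
#include <algorithm>
#include <cstddef>
#include <string>
#include <vector>

std::size_t longestName(const std::vector<std::string>& names)
{
    if (names.empty()) {
        return 5;   // fall back to the minimum width, mirroring std::max(size_t(5), ...)
    }
    auto it = std::max_element(names.begin(), names.end(),
        [](const std::string& a, const std::string& b) { return a.size() < b.size(); });
    return std::max(std::size_t(5), it->size());   // *it is valid because the range is non-empty
}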


@@ -7,6 +7,7 @@
#include "common/DotEnv.h" #include "common/DotEnv.h"
#include "common/Datasets.h" #include "common/Datasets.h"
#include "common/Paths.h" #include "common/Paths.h"
#include "common/Colors.h"
#include "main/Scores.h" #include "main/Scores.h"
#include "config.h" #include "config.h"
@@ -127,7 +128,7 @@ TEST_CASE("Confusion Matrix JSON", "[Scores]")
REQUIRE(res_json_str["Car"][1] == 2); REQUIRE(res_json_str["Car"][1] == 2);
REQUIRE(res_json_str["Car"][2] == 3); REQUIRE(res_json_str["Car"][2] == 3);
} }
TEST_CASE("Classification Report", "[Scores]") TEST_CASE("Classification Report", "[Scores]") -
{ {
std::vector<int> y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 }; std::vector<int> y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 };
std::vector<int> y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 }; std::vector<int> y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 };
@@ -135,19 +136,7 @@ TEST_CASE("Classification Report", "[Scores]")
     auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
     std::vector<std::string> labels = { "Aeroplane", "Boat", "Car" };
     platform::Scores scores(y_test_tensor, y_pred_tensor, 3, labels);
-    std::string expected = R"(Classification Report
-=====================
-precision recall f1-score support
-========= ========= ========= =========
-Aeroplane 0.6666667 0.6666667 0.6666667 3
-Boat 0.2500000 1.0000000 0.4000000 1
-Car 1.0000000 0.5000000 0.6666667 6
-accuracy 0.6000000 10
-macro avg 0.6388889 0.7222223 0.5777778 10
-weighted avg 0.8250000 0.6000000 0.6400000 10
-)";
-    REQUIRE(scores.classification_report() == expected);
+    auto report = scores.classification_report(Colors::BLUE(), "train");
     auto json_matrix = scores.get_confusion_matrix_json(true);
     platform::Scores scores2(json_matrix);
     REQUIRE(scores.classification_report() == scores2.classification_report());
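
A note on the two calls above: the test now invokes classification_report(Colors::BLUE(), "train") while the final REQUIRE still calls classification_report() with no arguments, which suggests defaulted colour and title parameters. The real platform::Scores signature is not shown in this diff; the following is only an assumed sketch of that overload pattern:

// Hypothetical stand-in for platform::Scores; it only illustrates the defaulted-parameter
// pattern implied by the test, not the real report layout or class.
#include <string>

class ScoresSketch {
public:
    std::string classification_report(const std::string& color = "", const std::string& title = "") const
    {
        // Wrap the report in the requested colour code and prepend the optional title.
        return color + (title.empty() ? "" : title + " ") + "Classification Report";
    }
};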