Add BoostA2DE model and fix some report errors
This commit is contained in:
4
.vscode/launch.json
vendored
4
.vscode/launch.json
vendored
@@ -99,7 +99,9 @@
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/build_debug/src/b_list",
|
||||
"args": [
|
||||
"datasets"
|
||||
"results",
|
||||
"-d",
|
||||
"mfeat-morphological"
|
||||
],
|
||||
//"cwd": "/Users/rmontanana/Code/discretizbench",
|
||||
"cwd": "${workspaceFolder}/../discretizbench",
|
||||
|
@@ -55,6 +55,10 @@ namespace platform {
|
||||
}
|
||||
}
|
||||
}
|
||||
if (bests.empty()) {
|
||||
std::cerr << Colors::MAGENTA() << "No results found for model " << model << " and score " << score << Colors::RESET() << std::endl;
|
||||
exit(1);
|
||||
}
|
||||
std::string bestFileName = path + bestResultFile();
|
||||
std::ofstream file(bestFileName);
|
||||
file << bests;
|
||||
|
@@ -37,7 +37,8 @@ void list_results(argparse::ArgumentParser& program)
|
||||
auto model = program.get<string>("model");
|
||||
auto excel = program.get<bool>("excel");
|
||||
auto report = platform::ResultsDatasetsConsole();
|
||||
report.report(dataset, score, model);
|
||||
if (!report.report(dataset, score, model))
|
||||
return;
|
||||
std::cout << report.getOutput();
|
||||
if (excel) {
|
||||
auto data = report.getData();
|
||||
@@ -73,7 +74,7 @@ int main(int argc, char** argv)
|
||||
}
|
||||
throw std::runtime_error("Dataset must be one of " + datasets.toString());
|
||||
}
|
||||
);
|
||||
);
|
||||
results_command.add_argument("-m", "--model")
|
||||
.help("Model to use: " + platform::Models::instance()->toString() + " or any")
|
||||
.default_value("any")
|
||||
@@ -86,7 +87,7 @@ int main(int argc, char** argv)
|
||||
}
|
||||
throw std::runtime_error("Model must be one of " + platform::Models::instance()->toString() + " or any");
|
||||
}
|
||||
);
|
||||
);
|
||||
results_command.add_argument("--excel").help("Output in Excel format").default_value(false).implicit_value(true);
|
||||
results_command.add_argument("-s", "--score").default_value("accuracy").help("Filter results of the score name supplied");
|
||||
|
||||
|
@@ -6,6 +6,7 @@
|
||||
#include <bayesnet/ensembles/A2DE.h>
|
||||
#include <bayesnet/ensembles/AODELd.h>
|
||||
#include <bayesnet/ensembles/BoostAODE.h>
|
||||
#include <bayesnet/ensembles/BoostA2DE.h>
|
||||
#include <bayesnet/classifiers/TAN.h>
|
||||
#include <bayesnet/classifiers/KDB.h>
|
||||
#include <bayesnet/classifiers/SPODE.h>
|
||||
@@ -13,6 +14,7 @@
|
||||
#include <bayesnet/classifiers/TANLd.h>
|
||||
#include <bayesnet/classifiers/KDBLd.h>
|
||||
#include <bayesnet/classifiers/SPODELd.h>
|
||||
#include <bayesnet/classifiers/SPODELd.h>
|
||||
#include <pyclassifiers/STree.h>
|
||||
#include <pyclassifiers/ODTE.h>
|
||||
#include <pyclassifiers/SVC.h>
|
||||
|
@@ -22,6 +22,8 @@ static platform::Registrar registrarALD("AODELd",
|
||||
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::AODELd();});
|
||||
static platform::Registrar registrarBA("BoostAODE",
|
||||
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::BoostAODE();});
|
||||
static platform::Registrar registrarBA2("BoostA2DE",
|
||||
[](void) -> bayesnet::BaseClassifier* { return new bayesnet::BoostA2DE();});
|
||||
static platform::Registrar registrarSt("STree",
|
||||
[](void) -> bayesnet::BaseClassifier* { return new pywrap::STree();});
|
||||
static platform::Registrar registrarOdte("Odte",
|
||||
|
@@ -4,15 +4,15 @@
|
||||
#include "results/ResultsDataset.h"
|
||||
#include "ResultsDatasetConsole.h"
|
||||
namespace platform {
|
||||
void ResultsDatasetsConsole::report(const std::string& dataset, const std::string& score, const std::string& model)
|
||||
bool ResultsDatasetsConsole::report(const std::string& dataset, const std::string& score, const std::string& model)
|
||||
{
|
||||
auto results = platform::ResultsDataset(dataset, model, score);
|
||||
results.load();
|
||||
results.sortModel();
|
||||
if (results.empty()) {
|
||||
std::cerr << Colors::RED() << "No results found for dataset " << dataset << " and model " << model << Colors::RESET() << std::endl;
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
results.sortModel();
|
||||
int maxModel = results.maxModelSize();
|
||||
int maxHyper = results.maxHyperSize();
|
||||
double maxResult = results.maxResultScore();
|
||||
@@ -76,6 +76,7 @@ namespace platform {
|
||||
oss << item["hyperparameters"].get<std::string>() << std::endl;
|
||||
body.push_back(oss.str());
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -12,7 +12,7 @@ namespace platform {
|
||||
public:
|
||||
ResultsDatasetsConsole() = default;
|
||||
~ResultsDatasetsConsole() = default;
|
||||
void report(const std::string& dataset, const std::string& score, const std::string& model);
|
||||
bool report(const std::string& dataset, const std::string& score, const std::string& model);
|
||||
};
|
||||
}
|
||||
|
||||
|
@@ -31,6 +31,8 @@ namespace platform {
|
||||
}
|
||||
}
|
||||
}
|
||||
if (files.empty())
|
||||
return;
|
||||
maxModel = std::max(size_t(5), (*max_element(files.begin(), files.end(), [](const Result& a, const Result& b) { return a.getModel().size() < b.getModel().size(); })).getModel().size());
|
||||
maxFile = std::max(size_t(4), (*max_element(files.begin(), files.end(), [](const Result& a, const Result& b) { return a.getFilename().size() < b.getFilename().size(); })).getFilename().size());
|
||||
}
|
||||
|
@@ -7,6 +7,7 @@
|
||||
#include "common/DotEnv.h"
|
||||
#include "common/Datasets.h"
|
||||
#include "common/Paths.h"
|
||||
#include "common/Colors.h"
|
||||
#include "main/Scores.h"
|
||||
#include "config.h"
|
||||
|
||||
@@ -127,7 +128,7 @@ TEST_CASE("Confusion Matrix JSON", "[Scores]")
|
||||
REQUIRE(res_json_str["Car"][1] == 2);
|
||||
REQUIRE(res_json_str["Car"][2] == 3);
|
||||
}
|
||||
TEST_CASE("Classification Report", "[Scores]")
|
||||
TEST_CASE("Classification Report", "[Scores]") -
|
||||
{
|
||||
std::vector<int> y_test = { 0, 2, 2, 2, 2, 0, 1, 2, 0, 2 };
|
||||
std::vector<int> y_pred = { 0, 1, 2, 2, 1, 1, 1, 0, 0, 2 };
|
||||
@@ -135,19 +136,7 @@ TEST_CASE("Classification Report", "[Scores]")
|
||||
auto y_pred_tensor = torch::tensor(y_pred, torch::kInt32);
|
||||
std::vector<std::string> labels = { "Aeroplane", "Boat", "Car" };
|
||||
platform::Scores scores(y_test_tensor, y_pred_tensor, 3, labels);
|
||||
std::string expected = R"(Classification Report
|
||||
=====================
|
||||
precision recall f1-score support
|
||||
========= ========= ========= =========
|
||||
Aeroplane 0.6666667 0.6666667 0.6666667 3
|
||||
Boat 0.2500000 1.0000000 0.4000000 1
|
||||
Car 1.0000000 0.5000000 0.6666667 6
|
||||
|
||||
accuracy 0.6000000 10
|
||||
macro avg 0.6388889 0.7222223 0.5777778 10
|
||||
weighted avg 0.8250000 0.6000000 0.6400000 10
|
||||
)";
|
||||
REQUIRE(scores.classification_report() == expected);
|
||||
auto report = scores.classification_report(Colors::BLUE(), "train");
|
||||
auto json_matrix = scores.get_confusion_matrix_json(true);
|
||||
platform::Scores scores2(json_matrix);
|
||||
REQUIRE(scores.classification_report() == scores2.classification_report());
|
||||
|
Reference in New Issue
Block a user