Complete predict & predict_proba in ensemble

This commit is contained in:
Ricardo Montañana Gómez 2024-02-24 18:36:09 +01:00
parent 8477698d8d
commit 02e456befb
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
9 changed files with 104 additions and 101 deletions

View File

@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- predict_proba method in Classifier
- predict_proba method in BoostAODE
- predict_voting parameter in BoostAODE constructor to choose between voting and probability-based prediction (default is voting)
- hyperparameter predict_voting to AODE, AODELd and BoostAODE (Ensemble child classes)
- tests to check predict & predict_proba coherence
## [1.0.2] - 2024-02-20

View File

@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.20)
project(BayesNet
VERSION 1.0.2
VERSION 1.0.3
DESCRIPTION "Bayesian Network and basic classifiers Library."
HOMEPAGE_URL "https://github.com/rmontanana/bayesnet"
LANGUAGES CXX

View File

@ -1,10 +1,26 @@
#include "AODE.h"
namespace bayesnet {
AODE::AODE() : Ensemble() {}
// Builds an AODE ensemble: forwards predict_voting to the Ensemble base
// constructor and sets validHyperparameters to the single name "predict_voting".
AODE::AODE(bool predict_voting) : Ensemble(predict_voting)
{
validHyperparameters = { "predict_voting" };
}
// Applies the hyperparameters supplied as JSON.
//
// The only key handled here is "predict_voting"; its value is stored in the
// predict_voting member. Anything else left in the input is rejected.
//
// The new value is committed only after the whole input has been validated, so
// a call that throws leaves predict_voting untouched.
//
// Throws std::invalid_argument when content other than "predict_voting" is
// present; the message carries a dump of the unrecognised content.
void AODE::setHyperparameters(const nlohmann::json& hyperparameters_)
{
    // Work on a copy so recognised keys can be removed as they are consumed.
    auto hyperparameters = hyperparameters_;
    // Stage the requested value instead of writing the member right away.
    auto requestedVoting = predict_voting;
    if (hyperparameters.contains("predict_voting")) {
        requestedVoting = hyperparameters["predict_voting"];
        hyperparameters.erase("predict_voting");
    }
    // Whatever is still present was not recognised.
    if (!hyperparameters.empty()) {
        throw std::invalid_argument("Invalid hyperparameters" + hyperparameters.dump());
    }
    // Validation passed: commit the staged value.
    predict_voting = requestedVoting;
}
void AODE::buildModel(const torch::Tensor& weights)
{
models.clear();
significanceModels.clear();
for (int i = 0; i < features.size(); ++i) {
models.push_back(std::make_unique<SPODE>(i));
}

View File

@ -4,12 +4,13 @@
#include "SPODE.h"
namespace bayesnet {
// AODE classifier declaration; derives from Ensemble.
// NOTE(review): this span contains two `public:` sections with repeated
// destructor/graph declarations and both AODE(bool) and AODE() constructors —
// it looks like the pre- and post-change member lists shown together; confirm
// against the actual header before relying on it.
class AODE : public Ensemble {
public:
// Constructor taking the predict_voting flag (defaults to true).
AODE(bool predict_voting = true);
virtual ~AODE() {};
// Overrides the base-class hook that receives hyperparameters as JSON.
void setHyperparameters(const nlohmann::json& hyperparameters) override;
// Returns a list of strings for the given title (defaults to "AODE").
std::vector<std::string> graph(const std::string& title = "AODE") const override;
protected:
// Overrides the base-class model construction hook.
void buildModel(const torch::Tensor& weights) override;
public:
AODE();
virtual ~AODE() {};
std::vector<std::string> graph(const std::string& title = "AODE") const override;
};
}
#endif

View File

@ -1,7 +1,22 @@
#include "AODELd.h"
namespace bayesnet {
AODELd::AODELd() : Ensemble(), Proposal(dataset, features, className) {}
// Builds an AODELd ensemble: forwards predict_voting to the Ensemble base,
// initialises the Proposal base with dataset, features and className, and sets
// validHyperparameters to the single name "predict_voting".
AODELd::AODELd(bool predict_voting) : Ensemble(predict_voting), Proposal(dataset, features, className)
{
validHyperparameters = { "predict_voting" };
}
// Applies the hyperparameters supplied as JSON.
//
// The only key handled here is "predict_voting"; its value is stored in the
// predict_voting member. Anything else left in the input is rejected.
//
// The new value is committed only after the whole input has been validated, so
// a call that throws leaves predict_voting untouched.
//
// Throws std::invalid_argument when content other than "predict_voting" is
// present; the message carries a dump of the unrecognised content.
void AODELd::setHyperparameters(const nlohmann::json& hyperparameters_)
{
    // Work on a copy so recognised keys can be removed as they are consumed.
    auto hyperparameters = hyperparameters_;
    // Stage the requested value instead of writing the member right away.
    auto requestedVoting = predict_voting;
    if (hyperparameters.contains("predict_voting")) {
        requestedVoting = hyperparameters["predict_voting"];
        hyperparameters.erase("predict_voting");
    }
    // Whatever is still present was not recognised.
    if (!hyperparameters.empty()) {
        throw std::invalid_argument("Invalid hyperparameters" + hyperparameters.dump());
    }
    // Validation passed: commit the staged value.
    predict_voting = requestedVoting;
}
AODELd& AODELd::fit(torch::Tensor& X_, torch::Tensor& y_, const std::vector<std::string>& features_, const std::string& className_, map<std::string, std::vector<int>>& states_)
{
checkInput(X_, y_);

View File

@ -6,15 +6,15 @@
namespace bayesnet {
// AODELd classifier declaration; derives from both Ensemble and Proposal.
// NOTE(review): this span contains two `public:` sections with repeated
// constructor/fit/destructor/graph declarations — it looks like the pre- and
// post-change member lists shown together; confirm against the actual header.
class AODELd : public Ensemble, public Proposal {
public:
// Constructor taking the predict_voting flag (defaults to true).
AODELd(bool predict_voting = true);
virtual ~AODELd() = default;
// Fits the classifier from tensors plus feature names, class name and states;
// returns a reference to an AODELd.
AODELd& fit(torch::Tensor& X_, torch::Tensor& y_, const std::vector<std::string>& features_, const std::string& className_, map<std::string, std::vector<int>>& states_) override;
// Overrides the base-class hook that receives hyperparameters as JSON.
void setHyperparameters(const nlohmann::json& hyperparameters) override;
// Returns a list of strings for the given name (defaults to "AODELd").
std::vector<std::string> graph(const std::string& name = "AODELd") const override;
protected:
// Overrides of the base-class training and model construction hooks.
void trainModel(const torch::Tensor& weights) override;
void buildModel(const torch::Tensor& weights) override;
public:
AODELd();
AODELd& fit(torch::Tensor& X_, torch::Tensor& y_, const std::vector<std::string>& features_, const std::string& className_, map<std::string, std::vector<int>>& states_) override;
virtual ~AODELd() = default;
std::vector<std::string> graph(const std::string& name = "AODELd") const override;
// Returns the fixed string "0.0.1".
static inline std::string version() { return "0.0.1"; };
};
}
#endif // !AODELD_H

View File

@ -10,13 +10,14 @@
namespace bayesnet {
// Builds a BoostAODE ensemble, forwarding predict_voting to the Ensemble base,
// and fills validHyperparameters with the hyperparameter names listed below.
BoostAODE::BoostAODE(bool predict_voting) : Ensemble(predict_voting)
{
// NOTE(review): validHyperparameters is assigned twice in this span; the lists
// differ only by the trailing "predict_voting" entry, and the second assignment
// is the one left in effect. Looks like old and new lines shown together — confirm.
validHyperparameters = { "repeatSparent", "maxModels", "ascending", "convergence", "threshold", "select_features", "tolerance" };
validHyperparameters = { "repeatSparent", "maxModels", "ascending", "convergence", "threshold", "select_features", "tolerance", "predict_voting" };
}
void BoostAODE::buildModel(const torch::Tensor& weights)
{
// Models shall be built in trainModel
models.clear();
significanceModels.clear();
n_models = 0;
// Prepare the validation dataset
auto y_ = dataset.index({ -1, "..." });
@ -72,6 +73,10 @@ namespace bayesnet {
tolerance = hyperparameters["tolerance"];
hyperparameters.erase("tolerance");
}
if (hyperparameters.contains("predict_voting")) {
predict_voting = hyperparameters["predict_voting"];
hyperparameters.erase("predict_voting");
}
if (hyperparameters.contains("select_features")) {
auto selectedAlgorithm = hyperparameters["select_features"];
std::vector<std::string> algos = { "IWSS", "FCBF", "CFS" };
@ -128,8 +133,11 @@ namespace bayesnet {
if (selectFeatures) {
featuresUsed = initializeModels();
}
if (maxModels == 0)
bool resetMaxModels = false;
if (maxModels == 0) {
maxModels = .1 * n > 10 ? .1 * n : n;
resetMaxModels = true; // Flag to unset maxModels
}
torch::Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
bool exitCondition = false;
// Variables to control the accuracy finish condition
@ -211,6 +219,9 @@ namespace bayesnet {
status = WARNING;
}
notes.push_back("Number of models: " + std::to_string(n_models));
if (resetMaxModels) {
maxModels = 0;
}
}
std::vector<std::string> BoostAODE::graph(const std::string& title) const
{

View File

@ -7,7 +7,7 @@
namespace bayesnet {
class BoostAODE : public Ensemble {
public:
BoostAODE(bool predict_voting = false);
BoostAODE(bool predict_voting = true);
virtual ~BoostAODE() = default;
std::vector<std::string> graph(const std::string& title = "BoostAODE") const override;
void setHyperparameters(const nlohmann::json& hyperparameters) override;

View File

@ -2,9 +2,6 @@
#include <catch2/catch_test_macros.hpp>
#include <catch2/catch_approx.hpp>
#include <catch2/generators/catch_generators.hpp>
#include <vector>
#include <map>
#include <string>
#include "KDB.h"
#include "TAN.h"
#include "SPODE.h"
@ -16,12 +13,9 @@
#include "AODELd.h"
#include "TestUtils.h"
// Checks that a KDB(2) classifier reports "1.0.2" from getVersion().
TEST_CASE("Library check version", "[BayesNet]")
{
auto clf = bayesnet::KDB(2);
REQUIRE(clf.getVersion() == "1.0.2");
}
TEST_CASE("Test Bayesian Classifiers score", "[BayesNet]")
const std::string ACTUAL_VERSION = "1.0.3";
TEST_CASE("Test Bayesian Classifiers score & version", "[BayesNet]")
{
map <pair<std::string, std::string>, float> scores = {
// Diabetes
@ -37,87 +31,34 @@ TEST_CASE("Test Bayesian Classifiers score", "[BayesNet]")
{{"iris", "AODE"}, 0.973333}, {{"iris", "KDB"}, 0.973333}, {{"iris", "SPODE"}, 0.973333}, {{"iris", "TAN"}, 0.973333},
{{"iris", "AODELd"}, 0.973333}, {{"iris", "KDBLd"}, 0.973333}, {{"iris", "SPODELd"}, 0.96f}, {{"iris", "TANLd"}, 0.97333f}, {{"iris", "BoostAODE"}, 0.98f}
};
std::map<std::string, bayesnet::BaseClassifier*> models = {
{"AODE", new bayesnet::AODE()}, {"AODELd", new bayesnet::AODELd()},
{"BoostAODE", new bayesnet::BoostAODE()},
{"KDB", new bayesnet::KDB(2)}, {"KDBLd", new bayesnet::KDBLd(2)},
{"SPODE", new bayesnet::SPODE(1)}, {"SPODELd", new bayesnet::SPODELd(1)},
{"TAN", new bayesnet::TAN()}, {"TANLd", new bayesnet::TANLd()}
};
std::string name = GENERATE("AODE", "AODELd", "KDB", "KDBLd", "SPODE", "SPODELd", "TAN", "TANLd");
auto clf = models[name];
std::string file_name = GENERATE("glass", "iris", "ecoli", "diabetes");
auto raw = RawDatasets(file_name, false);
SECTION("Test TAN classifier (" + file_name + ")")
SECTION("Test " + name + " classifier")
{
auto clf = bayesnet::TAN();
clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
auto score = clf.score(raw.Xv, raw.yv);
//scores[{file_name, "TAN"}] = score;
REQUIRE(score == Catch::Approx(scores[{file_name, "TAN"}]).epsilon(raw.epsilon));
for (const std::string& file_name : { "glass", "iris", "ecoli", "diabetes" }) {
auto clf = models[name];
auto discretize = name.substr(name.length() - 2) != "Ld";
auto raw = RawDatasets(file_name, discretize);
clf->fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest);
auto score = clf->score(raw.Xt, raw.yt);
INFO("File: " + file_name);
REQUIRE(score == Catch::Approx(scores[{file_name, name}]).epsilon(raw.epsilon));
}
}
SECTION("Test TANLd classifier (" + file_name + ")")
SECTION("Library check version")
{
auto clf = bayesnet::TANLd();
clf.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest);
auto score = clf.score(raw.Xt, raw.yt);
//scores[{file_name, "TANLd"}] = score;
REQUIRE(score == Catch::Approx(scores[{file_name, "TANLd"}]).epsilon(raw.epsilon));
INFO("Checking version of " + name + " classifier");
REQUIRE(clf->getVersion() == ACTUAL_VERSION);
}
SECTION("Test KDB classifier (" + file_name + ")")
{
auto clf = bayesnet::KDB(2);
clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
auto score = clf.score(raw.Xv, raw.yv);
//scores[{file_name, "KDB"}] = score;
REQUIRE(score == Catch::Approx(scores[{file_name, "KDB"
}]).epsilon(raw.epsilon));
}
SECTION("Test KDBLd classifier (" + file_name + ")")
{
auto clf = bayesnet::KDBLd(2);
clf.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest);
auto score = clf.score(raw.Xt, raw.yt);
//scores[{file_name, "KDBLd"}] = score;
REQUIRE(score == Catch::Approx(scores[{file_name, "KDBLd"
}]).epsilon(raw.epsilon));
}
SECTION("Test SPODE classifier (" + file_name + ")")
{
auto clf = bayesnet::SPODE(1);
clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
auto score = clf.score(raw.Xv, raw.yv);
// scores[{file_name, "SPODE"}] = score;
REQUIRE(score == Catch::Approx(scores[{file_name, "SPODE"}]).epsilon(raw.epsilon));
}
SECTION("Test SPODELd classifier (" + file_name + ")")
{
auto clf = bayesnet::SPODELd(1);
clf.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest);
auto score = clf.score(raw.Xt, raw.yt);
// scores[{file_name, "SPODELd"}] = score;
REQUIRE(score == Catch::Approx(scores[{file_name, "SPODELd"}]).epsilon(raw.epsilon));
}
SECTION("Test AODE classifier (" + file_name + ")")
{
auto clf = bayesnet::AODE();
clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
auto score = clf.score(raw.Xv, raw.yv);
// scores[{file_name, "AODE"}] = score;
REQUIRE(score == Catch::Approx(scores[{file_name, "AODE"}]).epsilon(raw.epsilon));
}
SECTION("Test AODELd classifier (" + file_name + ")")
{
auto clf = bayesnet::AODELd();
clf.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest);
auto score = clf.score(raw.Xt, raw.yt);
// scores[{file_name, "AODELd"}] = score;
REQUIRE(score == Catch::Approx(scores[{file_name, "AODELd"}]).epsilon(raw.epsilon));
}
SECTION("Test BoostAODE classifier (" + file_name + ")")
{
auto clf = bayesnet::BoostAODE(true);
clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
auto score = clf.score(raw.Xv, raw.yv);
// scores[{file_name, "BoostAODE"}] = score;
REQUIRE(score == Catch::Approx(scores[{file_name, "BoostAODE"}]).epsilon(raw.epsilon));
}
// for (auto scores : scores) {
// std::cout << "{{\"" << scores.first.first << "\", \"" << scores.first.second << "\"}, " << scores.second << "}, ";
// }
delete clf;
}
TEST_CASE("Models features", "[BayesNet]")
{
@ -264,3 +205,20 @@ TEST_CASE("Model predict_proba", "[BayesNet]")
delete clf;
}
}
// Exercises BoostAODE on the "iris" dataset under both settings of
// predict_voting: first disabled at construction time, then enabled through
// setHyperparameters on the same fitted model. Score and one predict_proba
// entry are captured for each setting and compared with reference values.
TEST_CASE("BoostAODE voting-proba", "[BayesNet]")
{
    // Row and column of the predict_proba output inspected below.
    const int sample = 83;
    const int column = 2;

    auto iris = RawDatasets("iris", false);
    bayesnet::BoostAODE model(false);
    model.fit(iris.Xv, iris.yv, iris.featuresv, iris.classNamev, iris.statesv);

    // Results with predict_voting == false.
    auto accuracyProba = model.score(iris.Xv, iris.yv);
    auto outputProba = model.predict_proba(iris.Xv);

    // Switch the already fitted model to voting and collect the same results.
    model.setHyperparameters({
        {"predict_voting",true},
        });
    auto accuracyVoting = model.score(iris.Xv, iris.yv);
    auto outputVoting = model.predict_proba(iris.Xv);

    REQUIRE(accuracyProba == Catch::Approx(0.97333).epsilon(iris.epsilon));
    REQUIRE(accuracyVoting == Catch::Approx(0.98).epsilon(iris.epsilon));
    REQUIRE(outputVoting[sample][column] == Catch::Approx(0.552091).epsilon(iris.epsilon));
    REQUIRE(outputProba[sample][column] == Catch::Approx(0.546017).epsilon(iris.epsilon));
}