Complete predict & predict_proba in ensemble
This commit is contained in:
parent
8477698d8d
commit
02e456befb
@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||||||
- predict_proba method in Classifier
|
- predict_proba method in Classifier
|
||||||
- predict_proba method in BoostAODE
|
- predict_proba method in BoostAODE
|
||||||
- predict_voting parameter in BoostAODE constructor to use voting or probability to predict (default is voting)
|
- predict_voting parameter in BoostAODE constructor to use voting or probability to predict (default is voting)
|
||||||
|
- hyperparameter predict_voting to AODE, AODELd and BoostAODE (Ensemble child classes)
|
||||||
|
- tests to check predict & predict_proba coherence
|
||||||
|
|
||||||
## [1.0.2] - 2024-02-20
|
## [1.0.2] - 2024-02-20
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
cmake_minimum_required(VERSION 3.20)
|
cmake_minimum_required(VERSION 3.20)
|
||||||
|
|
||||||
project(BayesNet
|
project(BayesNet
|
||||||
VERSION 1.0.2
|
VERSION 1.0.3
|
||||||
DESCRIPTION "Bayesian Network and basic classifiers Library."
|
DESCRIPTION "Bayesian Network and basic classifiers Library."
|
||||||
HOMEPAGE_URL "https://github.com/rmontanana/bayesnet"
|
HOMEPAGE_URL "https://github.com/rmontanana/bayesnet"
|
||||||
LANGUAGES CXX
|
LANGUAGES CXX
|
||||||
|
18
src/AODE.cc
18
src/AODE.cc
@ -1,10 +1,26 @@
|
|||||||
#include "AODE.h"
|
#include "AODE.h"
|
||||||
|
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
AODE::AODE() : Ensemble() {}
|
AODE::AODE(bool predict_voting) : Ensemble(predict_voting)
|
||||||
|
{
|
||||||
|
validHyperparameters = { "predict_voting" };
|
||||||
|
|
||||||
|
}
|
||||||
|
void AODE::setHyperparameters(const nlohmann::json& hyperparameters_)
|
||||||
|
{
|
||||||
|
auto hyperparameters = hyperparameters_;
|
||||||
|
if (hyperparameters.contains("predict_voting")) {
|
||||||
|
predict_voting = hyperparameters["predict_voting"];
|
||||||
|
hyperparameters.erase("predict_voting");
|
||||||
|
}
|
||||||
|
if (!hyperparameters.empty()) {
|
||||||
|
throw std::invalid_argument("Invalid hyperparameters" + hyperparameters.dump());
|
||||||
|
}
|
||||||
|
}
|
||||||
void AODE::buildModel(const torch::Tensor& weights)
|
void AODE::buildModel(const torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
models.clear();
|
models.clear();
|
||||||
|
significanceModels.clear();
|
||||||
for (int i = 0; i < features.size(); ++i) {
|
for (int i = 0; i < features.size(); ++i) {
|
||||||
models.push_back(std::make_unique<SPODE>(i));
|
models.push_back(std::make_unique<SPODE>(i));
|
||||||
}
|
}
|
||||||
|
@ -4,12 +4,13 @@
|
|||||||
#include "SPODE.h"
|
#include "SPODE.h"
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
class AODE : public Ensemble {
|
class AODE : public Ensemble {
|
||||||
|
public:
|
||||||
|
AODE(bool predict_voting = true);
|
||||||
|
virtual ~AODE() {};
|
||||||
|
void setHyperparameters(const nlohmann::json& hyperparameters) override;
|
||||||
|
std::vector<std::string> graph(const std::string& title = "AODE") const override;
|
||||||
protected:
|
protected:
|
||||||
void buildModel(const torch::Tensor& weights) override;
|
void buildModel(const torch::Tensor& weights) override;
|
||||||
public:
|
|
||||||
AODE();
|
|
||||||
virtual ~AODE() {};
|
|
||||||
std::vector<std::string> graph(const std::string& title = "AODE") const override;
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
@ -1,7 +1,22 @@
|
|||||||
#include "AODELd.h"
|
#include "AODELd.h"
|
||||||
|
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
AODELd::AODELd() : Ensemble(), Proposal(dataset, features, className) {}
|
AODELd::AODELd(bool predict_voting) : Ensemble(predict_voting), Proposal(dataset, features, className)
|
||||||
|
{
|
||||||
|
validHyperparameters = { "predict_voting" };
|
||||||
|
|
||||||
|
}
|
||||||
|
void AODELd::setHyperparameters(const nlohmann::json& hyperparameters_)
|
||||||
|
{
|
||||||
|
auto hyperparameters = hyperparameters_;
|
||||||
|
if (hyperparameters.contains("predict_voting")) {
|
||||||
|
predict_voting = hyperparameters["predict_voting"];
|
||||||
|
hyperparameters.erase("predict_voting");
|
||||||
|
}
|
||||||
|
if (!hyperparameters.empty()) {
|
||||||
|
throw std::invalid_argument("Invalid hyperparameters" + hyperparameters.dump());
|
||||||
|
}
|
||||||
|
}
|
||||||
AODELd& AODELd::fit(torch::Tensor& X_, torch::Tensor& y_, const std::vector<std::string>& features_, const std::string& className_, map<std::string, std::vector<int>>& states_)
|
AODELd& AODELd::fit(torch::Tensor& X_, torch::Tensor& y_, const std::vector<std::string>& features_, const std::string& className_, map<std::string, std::vector<int>>& states_)
|
||||||
{
|
{
|
||||||
checkInput(X_, y_);
|
checkInput(X_, y_);
|
||||||
|
12
src/AODELd.h
12
src/AODELd.h
@ -6,15 +6,15 @@
|
|||||||
|
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
class AODELd : public Ensemble, public Proposal {
|
class AODELd : public Ensemble, public Proposal {
|
||||||
|
public:
|
||||||
|
AODELd(bool predict_voting = true);
|
||||||
|
virtual ~AODELd() = default;
|
||||||
|
AODELd& fit(torch::Tensor& X_, torch::Tensor& y_, const std::vector<std::string>& features_, const std::string& className_, map<std::string, std::vector<int>>& states_) override;
|
||||||
|
void setHyperparameters(const nlohmann::json& hyperparameters) override;
|
||||||
|
std::vector<std::string> graph(const std::string& name = "AODELd") const override;
|
||||||
protected:
|
protected:
|
||||||
void trainModel(const torch::Tensor& weights) override;
|
void trainModel(const torch::Tensor& weights) override;
|
||||||
void buildModel(const torch::Tensor& weights) override;
|
void buildModel(const torch::Tensor& weights) override;
|
||||||
public:
|
|
||||||
AODELd();
|
|
||||||
AODELd& fit(torch::Tensor& X_, torch::Tensor& y_, const std::vector<std::string>& features_, const std::string& className_, map<std::string, std::vector<int>>& states_) override;
|
|
||||||
virtual ~AODELd() = default;
|
|
||||||
std::vector<std::string> graph(const std::string& name = "AODELd") const override;
|
|
||||||
static inline std::string version() { return "0.0.1"; };
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
#endif // !AODELD_H
|
#endif // !AODELD_H
|
@ -10,13 +10,14 @@
|
|||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
BoostAODE::BoostAODE(bool predict_voting) : Ensemble(predict_voting)
|
BoostAODE::BoostAODE(bool predict_voting) : Ensemble(predict_voting)
|
||||||
{
|
{
|
||||||
validHyperparameters = { "repeatSparent", "maxModels", "ascending", "convergence", "threshold", "select_features", "tolerance" };
|
validHyperparameters = { "repeatSparent", "maxModels", "ascending", "convergence", "threshold", "select_features", "tolerance", "predict_voting" };
|
||||||
|
|
||||||
}
|
}
|
||||||
void BoostAODE::buildModel(const torch::Tensor& weights)
|
void BoostAODE::buildModel(const torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
// Models shall be built in trainModel
|
// Models shall be built in trainModel
|
||||||
models.clear();
|
models.clear();
|
||||||
|
significanceModels.clear();
|
||||||
n_models = 0;
|
n_models = 0;
|
||||||
// Prepare the validation dataset
|
// Prepare the validation dataset
|
||||||
auto y_ = dataset.index({ -1, "..." });
|
auto y_ = dataset.index({ -1, "..." });
|
||||||
@ -72,6 +73,10 @@ namespace bayesnet {
|
|||||||
tolerance = hyperparameters["tolerance"];
|
tolerance = hyperparameters["tolerance"];
|
||||||
hyperparameters.erase("tolerance");
|
hyperparameters.erase("tolerance");
|
||||||
}
|
}
|
||||||
|
if (hyperparameters.contains("predict_voting")) {
|
||||||
|
predict_voting = hyperparameters["predict_voting"];
|
||||||
|
hyperparameters.erase("predict_voting");
|
||||||
|
}
|
||||||
if (hyperparameters.contains("select_features")) {
|
if (hyperparameters.contains("select_features")) {
|
||||||
auto selectedAlgorithm = hyperparameters["select_features"];
|
auto selectedAlgorithm = hyperparameters["select_features"];
|
||||||
std::vector<std::string> algos = { "IWSS", "FCBF", "CFS" };
|
std::vector<std::string> algos = { "IWSS", "FCBF", "CFS" };
|
||||||
@ -128,8 +133,11 @@ namespace bayesnet {
|
|||||||
if (selectFeatures) {
|
if (selectFeatures) {
|
||||||
featuresUsed = initializeModels();
|
featuresUsed = initializeModels();
|
||||||
}
|
}
|
||||||
if (maxModels == 0)
|
bool resetMaxModels = false;
|
||||||
|
if (maxModels == 0) {
|
||||||
maxModels = .1 * n > 10 ? .1 * n : n;
|
maxModels = .1 * n > 10 ? .1 * n : n;
|
||||||
|
resetMaxModels = true; // Flag to unset maxModels
|
||||||
|
}
|
||||||
torch::Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
|
torch::Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
|
||||||
bool exitCondition = false;
|
bool exitCondition = false;
|
||||||
// Variables to control the accuracy finish condition
|
// Variables to control the accuracy finish condition
|
||||||
@ -211,6 +219,9 @@ namespace bayesnet {
|
|||||||
status = WARNING;
|
status = WARNING;
|
||||||
}
|
}
|
||||||
notes.push_back("Number of models: " + std::to_string(n_models));
|
notes.push_back("Number of models: " + std::to_string(n_models));
|
||||||
|
if (resetMaxModels) {
|
||||||
|
maxModels = 0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
std::vector<std::string> BoostAODE::graph(const std::string& title) const
|
std::vector<std::string> BoostAODE::graph(const std::string& title) const
|
||||||
{
|
{
|
||||||
|
@ -7,7 +7,7 @@
|
|||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
class BoostAODE : public Ensemble {
|
class BoostAODE : public Ensemble {
|
||||||
public:
|
public:
|
||||||
BoostAODE(bool predict_voting = false);
|
BoostAODE(bool predict_voting = true);
|
||||||
virtual ~BoostAODE() = default;
|
virtual ~BoostAODE() = default;
|
||||||
std::vector<std::string> graph(const std::string& title = "BoostAODE") const override;
|
std::vector<std::string> graph(const std::string& title = "BoostAODE") const override;
|
||||||
void setHyperparameters(const nlohmann::json& hyperparameters) override;
|
void setHyperparameters(const nlohmann::json& hyperparameters) override;
|
||||||
|
@ -2,9 +2,6 @@
|
|||||||
#include <catch2/catch_test_macros.hpp>
|
#include <catch2/catch_test_macros.hpp>
|
||||||
#include <catch2/catch_approx.hpp>
|
#include <catch2/catch_approx.hpp>
|
||||||
#include <catch2/generators/catch_generators.hpp>
|
#include <catch2/generators/catch_generators.hpp>
|
||||||
#include <vector>
|
|
||||||
#include <map>
|
|
||||||
#include <string>
|
|
||||||
#include "KDB.h"
|
#include "KDB.h"
|
||||||
#include "TAN.h"
|
#include "TAN.h"
|
||||||
#include "SPODE.h"
|
#include "SPODE.h"
|
||||||
@ -16,12 +13,9 @@
|
|||||||
#include "AODELd.h"
|
#include "AODELd.h"
|
||||||
#include "TestUtils.h"
|
#include "TestUtils.h"
|
||||||
|
|
||||||
TEST_CASE("Library check version", "[BayesNet]")
|
const std::string ACTUAL_VERSION = "1.0.3";
|
||||||
{
|
|
||||||
auto clf = bayesnet::KDB(2);
|
TEST_CASE("Test Bayesian Classifiers score & version", "[BayesNet]")
|
||||||
REQUIRE(clf.getVersion() == "1.0.2");
|
|
||||||
}
|
|
||||||
TEST_CASE("Test Bayesian Classifiers score", "[BayesNet]")
|
|
||||||
{
|
{
|
||||||
map <pair<std::string, std::string>, float> scores = {
|
map <pair<std::string, std::string>, float> scores = {
|
||||||
// Diabetes
|
// Diabetes
|
||||||
@ -37,87 +31,34 @@ TEST_CASE("Test Bayesian Classifiers score", "[BayesNet]")
|
|||||||
{{"iris", "AODE"}, 0.973333}, {{"iris", "KDB"}, 0.973333}, {{"iris", "SPODE"}, 0.973333}, {{"iris", "TAN"}, 0.973333},
|
{{"iris", "AODE"}, 0.973333}, {{"iris", "KDB"}, 0.973333}, {{"iris", "SPODE"}, 0.973333}, {{"iris", "TAN"}, 0.973333},
|
||||||
{{"iris", "AODELd"}, 0.973333}, {{"iris", "KDBLd"}, 0.973333}, {{"iris", "SPODELd"}, 0.96f}, {{"iris", "TANLd"}, 0.97333f}, {{"iris", "BoostAODE"}, 0.98f}
|
{{"iris", "AODELd"}, 0.973333}, {{"iris", "KDBLd"}, 0.973333}, {{"iris", "SPODELd"}, 0.96f}, {{"iris", "TANLd"}, 0.97333f}, {{"iris", "BoostAODE"}, 0.98f}
|
||||||
};
|
};
|
||||||
|
std::map<std::string, bayesnet::BaseClassifier*> models = {
|
||||||
|
{"AODE", new bayesnet::AODE()}, {"AODELd", new bayesnet::AODELd()},
|
||||||
|
{"BoostAODE", new bayesnet::BoostAODE()},
|
||||||
|
{"KDB", new bayesnet::KDB(2)}, {"KDBLd", new bayesnet::KDBLd(2)},
|
||||||
|
{"SPODE", new bayesnet::SPODE(1)}, {"SPODELd", new bayesnet::SPODELd(1)},
|
||||||
|
{"TAN", new bayesnet::TAN()}, {"TANLd", new bayesnet::TANLd()}
|
||||||
|
};
|
||||||
|
std::string name = GENERATE("AODE", "AODELd", "KDB", "KDBLd", "SPODE", "SPODELd", "TAN", "TANLd");
|
||||||
|
auto clf = models[name];
|
||||||
|
|
||||||
std::string file_name = GENERATE("glass", "iris", "ecoli", "diabetes");
|
SECTION("Test " + name + " classifier")
|
||||||
auto raw = RawDatasets(file_name, false);
|
|
||||||
|
|
||||||
SECTION("Test TAN classifier (" + file_name + ")")
|
|
||||||
{
|
{
|
||||||
auto clf = bayesnet::TAN();
|
for (const std::string& file_name : { "glass", "iris", "ecoli", "diabetes" }) {
|
||||||
clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
|
auto clf = models[name];
|
||||||
auto score = clf.score(raw.Xv, raw.yv);
|
auto discretize = name.substr(name.length() - 2) != "Ld";
|
||||||
//scores[{file_name, "TAN"}] = score;
|
auto raw = RawDatasets(file_name, discretize);
|
||||||
REQUIRE(score == Catch::Approx(scores[{file_name, "TAN"}]).epsilon(raw.epsilon));
|
clf->fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest);
|
||||||
|
auto score = clf->score(raw.Xt, raw.yt);
|
||||||
|
INFO("File: " + file_name);
|
||||||
|
REQUIRE(score == Catch::Approx(scores[{file_name, name}]).epsilon(raw.epsilon));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
SECTION("Test TANLd classifier (" + file_name + ")")
|
SECTION("Library check version")
|
||||||
{
|
{
|
||||||
auto clf = bayesnet::TANLd();
|
INFO("Checking version of " + name + " classifier");
|
||||||
clf.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest);
|
REQUIRE(clf->getVersion() == ACTUAL_VERSION);
|
||||||
auto score = clf.score(raw.Xt, raw.yt);
|
|
||||||
//scores[{file_name, "TANLd"}] = score;
|
|
||||||
REQUIRE(score == Catch::Approx(scores[{file_name, "TANLd"}]).epsilon(raw.epsilon));
|
|
||||||
}
|
}
|
||||||
SECTION("Test KDB classifier (" + file_name + ")")
|
delete clf;
|
||||||
{
|
|
||||||
auto clf = bayesnet::KDB(2);
|
|
||||||
clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
|
|
||||||
auto score = clf.score(raw.Xv, raw.yv);
|
|
||||||
//scores[{file_name, "KDB"}] = score;
|
|
||||||
REQUIRE(score == Catch::Approx(scores[{file_name, "KDB"
|
|
||||||
}]).epsilon(raw.epsilon));
|
|
||||||
}
|
|
||||||
SECTION("Test KDBLd classifier (" + file_name + ")")
|
|
||||||
{
|
|
||||||
auto clf = bayesnet::KDBLd(2);
|
|
||||||
clf.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest);
|
|
||||||
auto score = clf.score(raw.Xt, raw.yt);
|
|
||||||
//scores[{file_name, "KDBLd"}] = score;
|
|
||||||
REQUIRE(score == Catch::Approx(scores[{file_name, "KDBLd"
|
|
||||||
}]).epsilon(raw.epsilon));
|
|
||||||
}
|
|
||||||
SECTION("Test SPODE classifier (" + file_name + ")")
|
|
||||||
{
|
|
||||||
auto clf = bayesnet::SPODE(1);
|
|
||||||
clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
|
|
||||||
auto score = clf.score(raw.Xv, raw.yv);
|
|
||||||
// scores[{file_name, "SPODE"}] = score;
|
|
||||||
REQUIRE(score == Catch::Approx(scores[{file_name, "SPODE"}]).epsilon(raw.epsilon));
|
|
||||||
}
|
|
||||||
SECTION("Test SPODELd classifier (" + file_name + ")")
|
|
||||||
{
|
|
||||||
auto clf = bayesnet::SPODELd(1);
|
|
||||||
clf.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest);
|
|
||||||
auto score = clf.score(raw.Xt, raw.yt);
|
|
||||||
// scores[{file_name, "SPODELd"}] = score;
|
|
||||||
REQUIRE(score == Catch::Approx(scores[{file_name, "SPODELd"}]).epsilon(raw.epsilon));
|
|
||||||
}
|
|
||||||
SECTION("Test AODE classifier (" + file_name + ")")
|
|
||||||
{
|
|
||||||
auto clf = bayesnet::AODE();
|
|
||||||
clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
|
|
||||||
auto score = clf.score(raw.Xv, raw.yv);
|
|
||||||
// scores[{file_name, "AODE"}] = score;
|
|
||||||
REQUIRE(score == Catch::Approx(scores[{file_name, "AODE"}]).epsilon(raw.epsilon));
|
|
||||||
}
|
|
||||||
SECTION("Test AODELd classifier (" + file_name + ")")
|
|
||||||
{
|
|
||||||
auto clf = bayesnet::AODELd();
|
|
||||||
clf.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest);
|
|
||||||
auto score = clf.score(raw.Xt, raw.yt);
|
|
||||||
// scores[{file_name, "AODELd"}] = score;
|
|
||||||
REQUIRE(score == Catch::Approx(scores[{file_name, "AODELd"}]).epsilon(raw.epsilon));
|
|
||||||
}
|
|
||||||
SECTION("Test BoostAODE classifier (" + file_name + ")")
|
|
||||||
{
|
|
||||||
auto clf = bayesnet::BoostAODE(true);
|
|
||||||
clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
|
|
||||||
auto score = clf.score(raw.Xv, raw.yv);
|
|
||||||
// scores[{file_name, "BoostAODE"}] = score;
|
|
||||||
REQUIRE(score == Catch::Approx(scores[{file_name, "BoostAODE"}]).epsilon(raw.epsilon));
|
|
||||||
}
|
|
||||||
// for (auto scores : scores) {
|
|
||||||
// std::cout << "{{\"" << scores.first.first << "\", \"" << scores.first.second << "\"}, " << scores.second << "}, ";
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
TEST_CASE("Models features", "[BayesNet]")
|
TEST_CASE("Models features", "[BayesNet]")
|
||||||
{
|
{
|
||||||
@ -264,3 +205,20 @@ TEST_CASE("Model predict_proba", "[BayesNet]")
|
|||||||
delete clf;
|
delete clf;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
TEST_CASE("BoostAODE voting-proba", "[BayesNet]")
|
||||||
|
{
|
||||||
|
auto raw = RawDatasets("iris", false);
|
||||||
|
auto clf = bayesnet::BoostAODE(false);
|
||||||
|
clf.fit(raw.Xv, raw.yv, raw.featuresv, raw.classNamev, raw.statesv);
|
||||||
|
auto score_proba = clf.score(raw.Xv, raw.yv);
|
||||||
|
auto pred_proba = clf.predict_proba(raw.Xv);
|
||||||
|
clf.setHyperparameters({
|
||||||
|
{"predict_voting",true},
|
||||||
|
});
|
||||||
|
auto score_voting = clf.score(raw.Xv, raw.yv);
|
||||||
|
auto pred_voting = clf.predict_proba(raw.Xv);
|
||||||
|
REQUIRE(score_proba == Catch::Approx(0.97333).epsilon(raw.epsilon));
|
||||||
|
REQUIRE(score_voting == Catch::Approx(0.98).epsilon(raw.epsilon));
|
||||||
|
REQUIRE(pred_voting[83][2] == Catch::Approx(0.552091).epsilon(raw.epsilon));
|
||||||
|
REQUIRE(pred_proba[83][2] == Catch::Approx(0.546017).epsilon(raw.epsilon));
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user