Compare commits

...

6 Commits

17 changed files with 144 additions and 24 deletions

2
.gitmodules vendored
View File

@ -18,4 +18,4 @@
url = https://github.com/rmontanana/ArffFiles
[submodule "lib/mdlp"]
path = lib/mdlp
url = https://github.com/rmontanana/mdlp
url = https://github.com/rmontanana/mdlp

View File

@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
### Added
- Add a new hyperparameter to the BoostAODE class, *alphablock*, to control the way α is computed, with the last model or with the ensmble built so far. Default value is *false*.
- Add a new hyperparameter to the SPODE class, *parent*, to set the root node of the model. If not value is set the root parameter of the constructor is used.
- Add a new hyperparameter to the TAN class, *parent*, to set the root node of the model. If not set the first feature is used as root.
## [1.0.6] 2024-11-23
### Fixed

View File

@ -18,7 +18,7 @@ The only external dependency is [libtorch](https://pytorch.org/cppdocs/installin
```bash
wget https://download.pytorch.org/libtorch/nightly/cpu/libtorch-shared-with-deps-latest.zip
unzip libtorch-shared-with-deps-latest.zips
unzip libtorch-shared-with-deps-latest.zip
```
## Setup

View File

@ -8,14 +8,29 @@
namespace bayesnet {
SPODE::SPODE(int root) : Classifier(Network()), root(root) {}
SPODE::SPODE(int root) : Classifier(Network()), root(root)
{
validHyperparameters = { "parent" };
}
void SPODE::setHyperparameters(const nlohmann::json& hyperparameters_)
{
auto hyperparameters = hyperparameters_;
if (hyperparameters.contains("parent")) {
root = hyperparameters["parent"];
hyperparameters.erase("parent");
}
Classifier::setHyperparameters(hyperparameters);
}
void SPODE::buildModel(const torch::Tensor& weights)
{
// 0. Add all nodes to the model
addNodes();
// 1. Add edges from the class node to all other nodes
// 2. Add edges from the root node to all other nodes
if (root >= static_cast<int>(features.size())) {
throw std::invalid_argument("The parent node is not in the dataset");
}
for (int i = 0; i < static_cast<int>(features.size()); ++i) {
model.addEdge(className, features[i]);
if (i != root) {

View File

@ -10,14 +10,15 @@
namespace bayesnet {
class SPODE : public Classifier {
private:
int root;
protected:
void buildModel(const torch::Tensor& weights) override;
public:
explicit SPODE(int root);
virtual ~SPODE() = default;
void setHyperparameters(const nlohmann::json& hyperparameters_) override;
std::vector<std::string> graph(const std::string& name = "SPODE") const override;
protected:
void buildModel(const torch::Tensor& weights) override;
private:
int root;
};
}
#endif

View File

@ -7,8 +7,20 @@
#include "TAN.h"
namespace bayesnet {
TAN::TAN() : Classifier(Network()) {}
TAN::TAN() : Classifier(Network())
{
validHyperparameters = { "parent" };
}
void TAN::setHyperparameters(const nlohmann::json& hyperparameters_)
{
auto hyperparameters = hyperparameters_;
if (hyperparameters.contains("parent")) {
parent = hyperparameters["parent"];
hyperparameters.erase("parent");
}
Classifier::setHyperparameters(hyperparameters);
}
void TAN::buildModel(const torch::Tensor& weights)
{
// 0. Add all nodes to the model
@ -23,7 +35,10 @@ namespace bayesnet {
mi.push_back({ i, mi_value });
}
sort(mi.begin(), mi.end(), [](const auto& left, const auto& right) {return left.second < right.second;});
auto root = mi[mi.size() - 1].first;
auto root = parent == -1 ? mi[mi.size() - 1].first : parent;
if (root >= static_cast<int>(features.size())) {
throw std::invalid_argument("The parent node is not in the dataset");
}
// 2. Compute mutual information between each feature and the class
auto weights_matrix = metrics.conditionalEdge(weights);
// 3. Compute the maximum spanning tree

View File

@ -9,13 +9,15 @@
#include "Classifier.h"
namespace bayesnet {
class TAN : public Classifier {
private:
protected:
void buildModel(const torch::Tensor& weights) override;
public:
TAN();
virtual ~TAN() = default;
void setHyperparameters(const nlohmann::json& hyperparameters_) override;
std::vector<std::string> graph(const std::string& name = "TAN") const override;
protected:
void buildModel(const torch::Tensor& weights) override;
private:
int parent = -1;
};
}
#endif

View File

@ -12,7 +12,7 @@
namespace bayesnet {
Boost::Boost(bool predict_voting) : Ensemble(predict_voting)
{
validHyperparameters = { "order", "convergence", "convergence_best", "bisection", "threshold", "maxTolerance",
validHyperparameters = { "alpha_block", "order", "convergence", "convergence_best", "bisection", "threshold", "maxTolerance",
"predict_voting", "select_features", "block_update" };
}
void Boost::setHyperparameters(const nlohmann::json& hyperparameters_)
@ -26,6 +26,10 @@ namespace bayesnet {
}
hyperparameters.erase("order");
}
if (hyperparameters.contains("alpha_block")) {
alpha_block = hyperparameters["alpha_block"];
hyperparameters.erase("alpha_block");
}
if (hyperparameters.contains("convergence")) {
convergence = hyperparameters["convergence"];
hyperparameters.erase("convergence");
@ -66,6 +70,12 @@ namespace bayesnet {
block_update = hyperparameters["block_update"];
hyperparameters.erase("block_update");
}
if (block_update && alpha_block) {
throw std::invalid_argument("alpha_block and block_update cannot be true at the same time");
}
if (block_update && !bisection) {
throw std::invalid_argument("block_update needs bisection to be true");
}
Classifier::setHyperparameters(hyperparameters);
}
void Boost::buildModel(const torch::Tensor& weights)

View File

@ -45,8 +45,8 @@ namespace bayesnet {
std::string select_features_algorithm = Orders.DESC; // Selected feature selection algorithm
FeatureSelect* featureSelector = nullptr;
double threshold = -1;
bool block_update = false;
bool block_update = false; // if true, use block update algorithm, only meaningful if bisection is true
bool alpha_block = false; // if true, the alpha is computed with the ensemble built so far and the new model
};
}
#endif

View File

@ -92,7 +92,25 @@ namespace bayesnet {
model->fit(dataset, features, className, states, weights_, smoothing);
alpha_t = 0.0;
if (!block_update) {
auto ypred = model->predict(X_train);
torch::Tensor ypred;
if (alpha_block) {
//
// Compute the prediction with the current ensemble + model
//
// Add the model to the ensemble
n_models++;
models.push_back(std::move(model));
significanceModels.push_back(1);
// Compute the prediction
ypred = predict(X_train);
// Remove the model from the ensemble
model = std::move(models.back());
models.pop_back();
significanceModels.pop_back();
n_models--;
} else {
ypred = model->predict(X_train);
}
// Step 3.1: Compute the classifier amout of say
std::tie(weights_, alpha_t, finished) = update_weights(y_train, ypred, weights_);
}

@ -1 +0,0 @@
Subproject commit 029fe3b4609dd84cd939b73357f37bbb75bcf82f

@ -1 +1 @@
Subproject commit 2ac43e32ac1eac0c986702ec526cf5367a565ef0
Subproject commit 9652853d692ed3b8a38d89f70559209ffb988020

@ -1 +1 @@
Subproject commit 378e091795a70fced276cd882bd8a6a428668fe5
Subproject commit 620034ececc93991c5c1183b73c3768d81ca84b3

View File

@ -267,4 +267,36 @@ TEST_CASE("Predict, predict_proba & score without fitting", "[Models]")
REQUIRE_THROWS_WITH(clf.predict_proba(raw.Xt), message);
REQUIRE_THROWS_WITH(clf.score(raw.Xv, raw.yv), message);
REQUIRE_THROWS_WITH(clf.score(raw.Xt, raw.yt), message);
}
TEST_CASE("TAN & SPODE with hyperparameters", "[Models]")
{
auto raw = RawDatasets("iris", true);
auto clf = bayesnet::TAN();
clf.setHyperparameters({
{"parent", 1},
});
clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing);
auto score = clf.score(raw.Xv, raw.yv);
REQUIRE(score == Catch::Approx(0.973333).epsilon(raw.epsilon));
auto clf2 = bayesnet::SPODE(0);
clf2.setHyperparameters({
{"parent", 1},
});
clf2.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing);
auto score2 = clf2.score(raw.Xv, raw.yv);
REQUIRE(score2 == Catch::Approx(0.973333).epsilon(raw.epsilon));
}
TEST_CASE("TAN & SPODE with invalid hyperparameters", "[Models]")
{
auto raw = RawDatasets("iris", true);
auto clf = bayesnet::TAN();
clf.setHyperparameters({
{"parent", 5},
});
REQUIRE_THROWS_AS(clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing), std::invalid_argument);
auto clf2 = bayesnet::SPODE(0);
clf2.setHyperparameters({
{"parent", 5},
});
REQUIRE_THROWS_AS(clf2.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing), std::invalid_argument);
}

View File

@ -7,7 +7,7 @@
#include <type_traits>
#include <catch2/catch_test_macros.hpp>
#include <catch2/catch_approx.hpp>
#include <catch2/generators/catch_generators.hpp>
#include <catch2/generators/catch_generators.hpp>
#include <catch2/matchers/catch_matchers.hpp>
#include "bayesnet/ensembles/BoostAODE.h"
#include "TestUtils.h"
@ -136,8 +136,16 @@ TEST_CASE("Oddities", "[BoostAODE]")
clf.setHyperparameters(hyper.value());
REQUIRE_THROWS_AS(clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing), std::invalid_argument);
}
}
auto bad_hyper_fit2 = nlohmann::json{
{ { "alpha_block", true }, { "block_update", true } },
{ { "bisection", false }, { "block_update", true } },
};
for (const auto& hyper : bad_hyper_fit2.items()) {
INFO("BoostAODE hyper: " << hyper.value().dump());
REQUIRE_THROWS_AS(clf.setHyperparameters(hyper.value()), std::invalid_argument);
}
}
TEST_CASE("Bisection Best", "[BoostAODE]")
{
auto clf = bayesnet::BoostAODE();
@ -180,7 +188,6 @@ TEST_CASE("Bisection Best vs Last", "[BoostAODE]")
auto score_last = clf.score(raw.X_test, raw.y_test);
REQUIRE(score_last == Catch::Approx(0.976666689f).epsilon(raw.epsilon));
}
TEST_CASE("Block Update", "[BoostAODE]")
{
auto clf = bayesnet::BoostAODE();
@ -210,4 +217,19 @@ TEST_CASE("Block Update", "[BoostAODE]")
// std::cout << note << std::endl;
// }
// std::cout << "Score " << score << std::endl;
}
TEST_CASE("Alphablock", "[BoostAODE]")
{
auto clf_alpha = bayesnet::BoostAODE();
auto clf_no_alpha = bayesnet::BoostAODE();
auto raw = RawDatasets("diabetes", true);
clf_alpha.setHyperparameters({
{"alpha_block", true},
});
clf_alpha.fit(raw.X_train, raw.y_train, raw.features, raw.className, raw.states, raw.smoothing);
clf_no_alpha.fit(raw.X_train, raw.y_train, raw.features, raw.className, raw.states, raw.smoothing);
auto score_alpha = clf_alpha.score(raw.X_test, raw.y_test);
auto score_no_alpha = clf_no_alpha.score(raw.X_test, raw.y_test);
REQUIRE(score_alpha == Catch::Approx(0.720779f).epsilon(raw.epsilon));
REQUIRE(score_no_alpha == Catch::Approx(0.733766f).epsilon(raw.epsilon));
}

View File

@ -17,7 +17,7 @@
std::map<std::string, std::string> modules = {
{ "mdlp", "2.0.1" },
{ "Folding", "1.1.0" },
{ "Folding", "1.1.1" },
{ "json", "3.11" },
{ "ArffFiles", "1.1.0" }
};

@ -1 +1 @@
Subproject commit 506276c59217429c93abd2fe9507c7f45eb81072
Subproject commit 0321d2fce328b5e2ad106a8230ff20e0d5bf5501