diff --git a/CHANGELOG.md b/CHANGELOG.md index 11ad42a..f2b74c6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Add a new hyperparameter to the BoostAODE class, *alphablock*, to control the way α is computed, with the last model or with the ensmble built so far. Default value is *false*. +- Add a new hyperparameter to the SPODE class, *parent*, to set the root node of the model. If not value is set the root parameter of the constructor is used. +- Add a new hyperparameter to the TAN class, *parent*, to set the root node of the model. If not set the first feature is used as root. ## [1.0.6] 2024-11-23 diff --git a/bayesnet/classifiers/SPODE.cc b/bayesnet/classifiers/SPODE.cc index 7736e7e..a33d284 100644 --- a/bayesnet/classifiers/SPODE.cc +++ b/bayesnet/classifiers/SPODE.cc @@ -8,14 +8,29 @@ namespace bayesnet { - SPODE::SPODE(int root) : Classifier(Network()), root(root) {} + SPODE::SPODE(int root) : Classifier(Network()), root(root) + { + validHyperparameters = { "parent" }; + } + void SPODE::setHyperparameters(const nlohmann::json& hyperparameters_) + { + auto hyperparameters = hyperparameters_; + if (hyperparameters.contains("parent")) { + root = hyperparameters["parent"]; + hyperparameters.erase("parent"); + } + Classifier::setHyperparameters(hyperparameters); + } void SPODE::buildModel(const torch::Tensor& weights) { // 0. Add all nodes to the model addNodes(); // 1. Add edges from the class node to all other nodes // 2. Add edges from the root node to all other nodes + if (root >= static_cast(features.size())) { + throw std::invalid_argument("The parent node is not in the dataset"); + } for (int i = 0; i < static_cast(features.size()); ++i) { model.addEdge(className, features[i]); if (i != root) { diff --git a/bayesnet/classifiers/SPODE.h b/bayesnet/classifiers/SPODE.h index 7ecff63..67a1f49 100644 --- a/bayesnet/classifiers/SPODE.h +++ b/bayesnet/classifiers/SPODE.h @@ -10,14 +10,15 @@ namespace bayesnet { class SPODE : public Classifier { - private: - int root; - protected: - void buildModel(const torch::Tensor& weights) override; public: explicit SPODE(int root); virtual ~SPODE() = default; + void setHyperparameters(const nlohmann::json& hyperparameters_) override; std::vector graph(const std::string& name = "SPODE") const override; + protected: + void buildModel(const torch::Tensor& weights) override; + private: + int root; }; } #endif \ No newline at end of file diff --git a/bayesnet/classifiers/TAN.cc b/bayesnet/classifiers/TAN.cc index d2be0c7..2ec10eb 100644 --- a/bayesnet/classifiers/TAN.cc +++ b/bayesnet/classifiers/TAN.cc @@ -7,8 +7,20 @@ #include "TAN.h" namespace bayesnet { - TAN::TAN() : Classifier(Network()) {} + TAN::TAN() : Classifier(Network()) + { + validHyperparameters = { "parent" }; + } + void TAN::setHyperparameters(const nlohmann::json& hyperparameters_) + { + auto hyperparameters = hyperparameters_; + if (hyperparameters.contains("parent")) { + parent = hyperparameters["parent"]; + hyperparameters.erase("parent"); + } + Classifier::setHyperparameters(hyperparameters); + } void TAN::buildModel(const torch::Tensor& weights) { // 0. Add all nodes to the model @@ -23,7 +35,10 @@ namespace bayesnet { mi.push_back({ i, mi_value }); } sort(mi.begin(), mi.end(), [](const auto& left, const auto& right) {return left.second < right.second;}); - auto root = mi[mi.size() - 1].first; + auto root = parent == -1 ? mi[mi.size() - 1].first : parent; + if (root >= static_cast(features.size())) { + throw std::invalid_argument("The parent node is not in the dataset"); + } // 2. Compute mutual information between each feature and the class auto weights_matrix = metrics.conditionalEdge(weights); // 3. Compute the maximum spanning tree diff --git a/bayesnet/classifiers/TAN.h b/bayesnet/classifiers/TAN.h index 00d50f9..b68423e 100644 --- a/bayesnet/classifiers/TAN.h +++ b/bayesnet/classifiers/TAN.h @@ -9,13 +9,15 @@ #include "Classifier.h" namespace bayesnet { class TAN : public Classifier { - private: - protected: - void buildModel(const torch::Tensor& weights) override; public: TAN(); virtual ~TAN() = default; + void setHyperparameters(const nlohmann::json& hyperparameters_) override; std::vector graph(const std::string& name = "TAN") const override; + protected: + void buildModel(const torch::Tensor& weights) override; + private: + int parent = -1; }; } #endif \ No newline at end of file diff --git a/tests/TestBayesModels.cc b/tests/TestBayesModels.cc index e5113a2..50616ca 100644 --- a/tests/TestBayesModels.cc +++ b/tests/TestBayesModels.cc @@ -267,4 +267,36 @@ TEST_CASE("Predict, predict_proba & score without fitting", "[Models]") REQUIRE_THROWS_WITH(clf.predict_proba(raw.Xt), message); REQUIRE_THROWS_WITH(clf.score(raw.Xv, raw.yv), message); REQUIRE_THROWS_WITH(clf.score(raw.Xt, raw.yt), message); +} +TEST_CASE("TAN & SPODE with hyperparameters", "[Models]") +{ + auto raw = RawDatasets("iris", true); + auto clf = bayesnet::TAN(); + clf.setHyperparameters({ + {"parent", 1}, + }); + clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing); + auto score = clf.score(raw.Xv, raw.yv); + REQUIRE(score == Catch::Approx(0.973333).epsilon(raw.epsilon)); + auto clf2 = bayesnet::SPODE(0); + clf2.setHyperparameters({ + {"parent", 1}, + }); + clf2.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing); + auto score2 = clf2.score(raw.Xv, raw.yv); + REQUIRE(score2 == Catch::Approx(0.973333).epsilon(raw.epsilon)); +} +TEST_CASE("TAN & SPODE with invalid hyperparameters", "[Models]") +{ + auto raw = RawDatasets("iris", true); + auto clf = bayesnet::TAN(); + clf.setHyperparameters({ + {"parent", 5}, + }); + REQUIRE_THROWS_AS(clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing), std::invalid_argument); + auto clf2 = bayesnet::SPODE(0); + clf2.setHyperparameters({ + {"parent", 5}, + }); + REQUIRE_THROWS_AS(clf2.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing), std::invalid_argument); } \ No newline at end of file diff --git a/tests/TestBoostAODE.cc b/tests/TestBoostAODE.cc index 625ff7f..728f35b 100644 --- a/tests/TestBoostAODE.cc +++ b/tests/TestBoostAODE.cc @@ -7,7 +7,7 @@ #include #include #include -#include +#include #include #include "bayesnet/ensembles/BoostAODE.h" #include "TestUtils.h"