Add parent hyperparameter to TAN & SPODE
This commit is contained in:
parent
56a2d3ead0
commit
e2781ee525
@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
### Added
|
||||
|
||||
- Add a new hyperparameter to the BoostAODE class, *alphablock*, to control the way α is computed, with the last model or with the ensmble built so far. Default value is *false*.
|
||||
- Add a new hyperparameter to the SPODE class, *parent*, to set the root node of the model. If not value is set the root parameter of the constructor is used.
|
||||
- Add a new hyperparameter to the TAN class, *parent*, to set the root node of the model. If not set the first feature is used as root.
|
||||
|
||||
## [1.0.6] 2024-11-23
|
||||
|
||||
|
@ -8,14 +8,29 @@
|
||||
|
||||
namespace bayesnet {
|
||||
|
||||
SPODE::SPODE(int root) : Classifier(Network()), root(root) {}
|
||||
SPODE::SPODE(int root) : Classifier(Network()), root(root)
|
||||
{
|
||||
validHyperparameters = { "parent" };
|
||||
}
|
||||
|
||||
void SPODE::setHyperparameters(const nlohmann::json& hyperparameters_)
|
||||
{
|
||||
auto hyperparameters = hyperparameters_;
|
||||
if (hyperparameters.contains("parent")) {
|
||||
root = hyperparameters["parent"];
|
||||
hyperparameters.erase("parent");
|
||||
}
|
||||
Classifier::setHyperparameters(hyperparameters);
|
||||
}
|
||||
void SPODE::buildModel(const torch::Tensor& weights)
|
||||
{
|
||||
// 0. Add all nodes to the model
|
||||
addNodes();
|
||||
// 1. Add edges from the class node to all other nodes
|
||||
// 2. Add edges from the root node to all other nodes
|
||||
if (root >= static_cast<int>(features.size())) {
|
||||
throw std::invalid_argument("The parent node is not in the dataset");
|
||||
}
|
||||
for (int i = 0; i < static_cast<int>(features.size()); ++i) {
|
||||
model.addEdge(className, features[i]);
|
||||
if (i != root) {
|
||||
|
@ -10,14 +10,15 @@
|
||||
|
||||
namespace bayesnet {
|
||||
class SPODE : public Classifier {
|
||||
private:
|
||||
int root;
|
||||
protected:
|
||||
void buildModel(const torch::Tensor& weights) override;
|
||||
public:
|
||||
explicit SPODE(int root);
|
||||
virtual ~SPODE() = default;
|
||||
void setHyperparameters(const nlohmann::json& hyperparameters_) override;
|
||||
std::vector<std::string> graph(const std::string& name = "SPODE") const override;
|
||||
protected:
|
||||
void buildModel(const torch::Tensor& weights) override;
|
||||
private:
|
||||
int root;
|
||||
};
|
||||
}
|
||||
#endif
|
@ -7,8 +7,20 @@
|
||||
#include "TAN.h"
|
||||
|
||||
namespace bayesnet {
|
||||
TAN::TAN() : Classifier(Network()) {}
|
||||
TAN::TAN() : Classifier(Network())
|
||||
{
|
||||
validHyperparameters = { "parent" };
|
||||
}
|
||||
|
||||
void TAN::setHyperparameters(const nlohmann::json& hyperparameters_)
|
||||
{
|
||||
auto hyperparameters = hyperparameters_;
|
||||
if (hyperparameters.contains("parent")) {
|
||||
parent = hyperparameters["parent"];
|
||||
hyperparameters.erase("parent");
|
||||
}
|
||||
Classifier::setHyperparameters(hyperparameters);
|
||||
}
|
||||
void TAN::buildModel(const torch::Tensor& weights)
|
||||
{
|
||||
// 0. Add all nodes to the model
|
||||
@ -23,7 +35,10 @@ namespace bayesnet {
|
||||
mi.push_back({ i, mi_value });
|
||||
}
|
||||
sort(mi.begin(), mi.end(), [](const auto& left, const auto& right) {return left.second < right.second;});
|
||||
auto root = mi[mi.size() - 1].first;
|
||||
auto root = parent == -1 ? mi[mi.size() - 1].first : parent;
|
||||
if (root >= static_cast<int>(features.size())) {
|
||||
throw std::invalid_argument("The parent node is not in the dataset");
|
||||
}
|
||||
// 2. Compute mutual information between each feature and the class
|
||||
auto weights_matrix = metrics.conditionalEdge(weights);
|
||||
// 3. Compute the maximum spanning tree
|
||||
|
@ -9,13 +9,15 @@
|
||||
#include "Classifier.h"
|
||||
namespace bayesnet {
|
||||
class TAN : public Classifier {
|
||||
private:
|
||||
protected:
|
||||
void buildModel(const torch::Tensor& weights) override;
|
||||
public:
|
||||
TAN();
|
||||
virtual ~TAN() = default;
|
||||
void setHyperparameters(const nlohmann::json& hyperparameters_) override;
|
||||
std::vector<std::string> graph(const std::string& name = "TAN") const override;
|
||||
protected:
|
||||
void buildModel(const torch::Tensor& weights) override;
|
||||
private:
|
||||
int parent = -1;
|
||||
};
|
||||
}
|
||||
#endif
|
@ -267,4 +267,36 @@ TEST_CASE("Predict, predict_proba & score without fitting", "[Models]")
|
||||
REQUIRE_THROWS_WITH(clf.predict_proba(raw.Xt), message);
|
||||
REQUIRE_THROWS_WITH(clf.score(raw.Xv, raw.yv), message);
|
||||
REQUIRE_THROWS_WITH(clf.score(raw.Xt, raw.yt), message);
|
||||
}
|
||||
TEST_CASE("TAN & SPODE with hyperparameters", "[Models]")
|
||||
{
|
||||
auto raw = RawDatasets("iris", true);
|
||||
auto clf = bayesnet::TAN();
|
||||
clf.setHyperparameters({
|
||||
{"parent", 1},
|
||||
});
|
||||
clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing);
|
||||
auto score = clf.score(raw.Xv, raw.yv);
|
||||
REQUIRE(score == Catch::Approx(0.973333).epsilon(raw.epsilon));
|
||||
auto clf2 = bayesnet::SPODE(0);
|
||||
clf2.setHyperparameters({
|
||||
{"parent", 1},
|
||||
});
|
||||
clf2.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing);
|
||||
auto score2 = clf2.score(raw.Xv, raw.yv);
|
||||
REQUIRE(score2 == Catch::Approx(0.973333).epsilon(raw.epsilon));
|
||||
}
|
||||
TEST_CASE("TAN & SPODE with invalid hyperparameters", "[Models]")
|
||||
{
|
||||
auto raw = RawDatasets("iris", true);
|
||||
auto clf = bayesnet::TAN();
|
||||
clf.setHyperparameters({
|
||||
{"parent", 5},
|
||||
});
|
||||
REQUIRE_THROWS_AS(clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing), std::invalid_argument);
|
||||
auto clf2 = bayesnet::SPODE(0);
|
||||
clf2.setHyperparameters({
|
||||
{"parent", 5},
|
||||
});
|
||||
REQUIRE_THROWS_AS(clf2.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing), std::invalid_argument);
|
||||
}
|
@ -7,7 +7,7 @@
|
||||
#include <type_traits>
|
||||
#include <catch2/catch_test_macros.hpp>
|
||||
#include <catch2/catch_approx.hpp>
|
||||
#include <catch2/generators/catch_generators.hpp>
|
||||
#include <catch2/generators/catch_generators.hpp>
|
||||
#include <catch2/matchers/catch_matchers.hpp>
|
||||
#include "bayesnet/ensembles/BoostAODE.h"
|
||||
#include "TestUtils.h"
|
||||
|
Loading…
Reference in New Issue
Block a user