Compare commits
5 Commits
3d6b4f0614
...
alphablock
Author | SHA1 | Date | |
---|---|---|---|
b571a4da4d
|
|||
8a9f329ff9
|
|||
e2781ee525
|
|||
56a2d3ead0
|
|||
dc32a0fc47
|
2
.gitmodules
vendored
2
.gitmodules
vendored
@@ -18,4 +18,4 @@
|
|||||||
url = https://github.com/rmontanana/ArffFiles
|
url = https://github.com/rmontanana/ArffFiles
|
||||||
[submodule "lib/mdlp"]
|
[submodule "lib/mdlp"]
|
||||||
path = lib/mdlp
|
path = lib/mdlp
|
||||||
url = https://github.com/rmontanana/mdlp
|
url = https://github.com/rmontanana/mdlp
|
@@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||||||
### Added
|
### Added
|
||||||
|
|
||||||
- Add a new hyperparameter to the BoostAODE class, *alphablock*, to control the way α is computed, with the last model or with the ensemble built so far. Default value is *false*.
|
- Add a new hyperparameter to the BoostAODE class, *alphablock*, to control the way α is computed, with the last model or with the ensemble built so far. Default value is *false*.
|
||||||
|
- Add a new hyperparameter to the SPODE class, *parent*, to set the root node of the model. If no value is set the root parameter of the constructor is used.
|
||||||
|
- Add a new hyperparameter to the TAN class, *parent*, to set the root node of the model. If not set the first feature is used as root.
|
||||||
|
|
||||||
## [1.0.6] 2024-11-23
|
## [1.0.6] 2024-11-23
|
||||||
|
|
||||||
|
@@ -18,7 +18,7 @@ The only external dependency is [libtorch](https://pytorch.org/cppdocs/installin
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
wget https://download.pytorch.org/libtorch/nightly/cpu/libtorch-shared-with-deps-latest.zip
|
wget https://download.pytorch.org/libtorch/nightly/cpu/libtorch-shared-with-deps-latest.zip
|
||||||
unzip libtorch-shared-with-deps-latest.zips
|
unzip libtorch-shared-with-deps-latest.zip
|
||||||
```
|
```
|
||||||
|
|
||||||
## Setup
|
## Setup
|
||||||
|
@@ -8,14 +8,29 @@
|
|||||||
|
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
|
|
||||||
SPODE::SPODE(int root) : Classifier(Network()), root(root) {}
|
SPODE::SPODE(int root) : Classifier(Network()), root(root)
|
||||||
|
{
|
||||||
|
validHyperparameters = { "parent" };
|
||||||
|
}
|
||||||
|
|
||||||
|
void SPODE::setHyperparameters(const nlohmann::json& hyperparameters_)
|
||||||
|
{
|
||||||
|
auto hyperparameters = hyperparameters_;
|
||||||
|
if (hyperparameters.contains("parent")) {
|
||||||
|
root = hyperparameters["parent"];
|
||||||
|
hyperparameters.erase("parent");
|
||||||
|
}
|
||||||
|
Classifier::setHyperparameters(hyperparameters);
|
||||||
|
}
|
||||||
void SPODE::buildModel(const torch::Tensor& weights)
|
void SPODE::buildModel(const torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
// 0. Add all nodes to the model
|
// 0. Add all nodes to the model
|
||||||
addNodes();
|
addNodes();
|
||||||
// 1. Add edges from the class node to all other nodes
|
// 1. Add edges from the class node to all other nodes
|
||||||
// 2. Add edges from the root node to all other nodes
|
// 2. Add edges from the root node to all other nodes
|
||||||
|
if (root >= static_cast<int>(features.size())) {
|
||||||
|
throw std::invalid_argument("The parent node is not in the dataset");
|
||||||
|
}
|
||||||
for (int i = 0; i < static_cast<int>(features.size()); ++i) {
|
for (int i = 0; i < static_cast<int>(features.size()); ++i) {
|
||||||
model.addEdge(className, features[i]);
|
model.addEdge(className, features[i]);
|
||||||
if (i != root) {
|
if (i != root) {
|
||||||
|
@@ -10,14 +10,15 @@
|
|||||||
|
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
class SPODE : public Classifier {
|
class SPODE : public Classifier {
|
||||||
private:
|
|
||||||
int root;
|
|
||||||
protected:
|
|
||||||
void buildModel(const torch::Tensor& weights) override;
|
|
||||||
public:
|
public:
|
||||||
explicit SPODE(int root);
|
explicit SPODE(int root);
|
||||||
virtual ~SPODE() = default;
|
virtual ~SPODE() = default;
|
||||||
|
void setHyperparameters(const nlohmann::json& hyperparameters_) override;
|
||||||
std::vector<std::string> graph(const std::string& name = "SPODE") const override;
|
std::vector<std::string> graph(const std::string& name = "SPODE") const override;
|
||||||
|
protected:
|
||||||
|
void buildModel(const torch::Tensor& weights) override;
|
||||||
|
private:
|
||||||
|
int root;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
@@ -7,8 +7,20 @@
|
|||||||
#include "TAN.h"
|
#include "TAN.h"
|
||||||
|
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
TAN::TAN() : Classifier(Network()) {}
|
TAN::TAN() : Classifier(Network())
|
||||||
|
{
|
||||||
|
validHyperparameters = { "parent" };
|
||||||
|
}
|
||||||
|
|
||||||
|
void TAN::setHyperparameters(const nlohmann::json& hyperparameters_)
|
||||||
|
{
|
||||||
|
auto hyperparameters = hyperparameters_;
|
||||||
|
if (hyperparameters.contains("parent")) {
|
||||||
|
parent = hyperparameters["parent"];
|
||||||
|
hyperparameters.erase("parent");
|
||||||
|
}
|
||||||
|
Classifier::setHyperparameters(hyperparameters);
|
||||||
|
}
|
||||||
void TAN::buildModel(const torch::Tensor& weights)
|
void TAN::buildModel(const torch::Tensor& weights)
|
||||||
{
|
{
|
||||||
// 0. Add all nodes to the model
|
// 0. Add all nodes to the model
|
||||||
@@ -23,7 +35,10 @@ namespace bayesnet {
|
|||||||
mi.push_back({ i, mi_value });
|
mi.push_back({ i, mi_value });
|
||||||
}
|
}
|
||||||
sort(mi.begin(), mi.end(), [](const auto& left, const auto& right) {return left.second < right.second;});
|
sort(mi.begin(), mi.end(), [](const auto& left, const auto& right) {return left.second < right.second;});
|
||||||
auto root = mi[mi.size() - 1].first;
|
auto root = parent == -1 ? mi[mi.size() - 1].first : parent;
|
||||||
|
if (root >= static_cast<int>(features.size())) {
|
||||||
|
throw std::invalid_argument("The parent node is not in the dataset");
|
||||||
|
}
|
||||||
// 2. Compute mutual information between each feature and the class
|
// 2. Compute mutual information between each feature and the class
|
||||||
auto weights_matrix = metrics.conditionalEdge(weights);
|
auto weights_matrix = metrics.conditionalEdge(weights);
|
||||||
// 3. Compute the maximum spanning tree
|
// 3. Compute the maximum spanning tree
|
||||||
|
@@ -9,13 +9,15 @@
|
|||||||
#include "Classifier.h"
|
#include "Classifier.h"
|
||||||
namespace bayesnet {
|
namespace bayesnet {
|
||||||
class TAN : public Classifier {
|
class TAN : public Classifier {
|
||||||
private:
|
|
||||||
protected:
|
|
||||||
void buildModel(const torch::Tensor& weights) override;
|
|
||||||
public:
|
public:
|
||||||
TAN();
|
TAN();
|
||||||
virtual ~TAN() = default;
|
virtual ~TAN() = default;
|
||||||
|
void setHyperparameters(const nlohmann::json& hyperparameters_) override;
|
||||||
std::vector<std::string> graph(const std::string& name = "TAN") const override;
|
std::vector<std::string> graph(const std::string& name = "TAN") const override;
|
||||||
|
protected:
|
||||||
|
void buildModel(const torch::Tensor& weights) override;
|
||||||
|
private:
|
||||||
|
int parent = -1;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
Submodule lib/catch2 deleted from 029fe3b460
Submodule lib/folding updated: 2ac43e32ac...9652853d69
2
lib/json
2
lib/json
Submodule lib/json updated: 378e091795...620034ecec
@@ -267,4 +267,36 @@ TEST_CASE("Predict, predict_proba & score without fitting", "[Models]")
|
|||||||
REQUIRE_THROWS_WITH(clf.predict_proba(raw.Xt), message);
|
REQUIRE_THROWS_WITH(clf.predict_proba(raw.Xt), message);
|
||||||
REQUIRE_THROWS_WITH(clf.score(raw.Xv, raw.yv), message);
|
REQUIRE_THROWS_WITH(clf.score(raw.Xv, raw.yv), message);
|
||||||
REQUIRE_THROWS_WITH(clf.score(raw.Xt, raw.yt), message);
|
REQUIRE_THROWS_WITH(clf.score(raw.Xt, raw.yt), message);
|
||||||
|
}
|
||||||
|
TEST_CASE("TAN & SPODE with hyperparameters", "[Models]")
|
||||||
|
{
|
||||||
|
auto raw = RawDatasets("iris", true);
|
||||||
|
auto clf = bayesnet::TAN();
|
||||||
|
clf.setHyperparameters({
|
||||||
|
{"parent", 1},
|
||||||
|
});
|
||||||
|
clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing);
|
||||||
|
auto score = clf.score(raw.Xv, raw.yv);
|
||||||
|
REQUIRE(score == Catch::Approx(0.973333).epsilon(raw.epsilon));
|
||||||
|
auto clf2 = bayesnet::SPODE(0);
|
||||||
|
clf2.setHyperparameters({
|
||||||
|
{"parent", 1},
|
||||||
|
});
|
||||||
|
clf2.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing);
|
||||||
|
auto score2 = clf2.score(raw.Xv, raw.yv);
|
||||||
|
REQUIRE(score2 == Catch::Approx(0.973333).epsilon(raw.epsilon));
|
||||||
|
}
|
||||||
|
TEST_CASE("TAN & SPODE with invalid hyperparameters", "[Models]")
|
||||||
|
{
|
||||||
|
auto raw = RawDatasets("iris", true);
|
||||||
|
auto clf = bayesnet::TAN();
|
||||||
|
clf.setHyperparameters({
|
||||||
|
{"parent", 5},
|
||||||
|
});
|
||||||
|
REQUIRE_THROWS_AS(clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing), std::invalid_argument);
|
||||||
|
auto clf2 = bayesnet::SPODE(0);
|
||||||
|
clf2.setHyperparameters({
|
||||||
|
{"parent", 5},
|
||||||
|
});
|
||||||
|
REQUIRE_THROWS_AS(clf2.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing), std::invalid_argument);
|
||||||
}
|
}
|
@@ -7,7 +7,7 @@
|
|||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
#include <catch2/catch_test_macros.hpp>
|
#include <catch2/catch_test_macros.hpp>
|
||||||
#include <catch2/catch_approx.hpp>
|
#include <catch2/catch_approx.hpp>
|
||||||
#include <catch2/generators/catch_generators.hpp>
|
#include <catch2/generators/catch_generators.hpp>
|
||||||
#include <catch2/matchers/catch_matchers.hpp>
|
#include <catch2/matchers/catch_matchers.hpp>
|
||||||
#include "bayesnet/ensembles/BoostAODE.h"
|
#include "bayesnet/ensembles/BoostAODE.h"
|
||||||
#include "TestUtils.h"
|
#include "TestUtils.h"
|
||||||
@@ -130,14 +130,21 @@ TEST_CASE("Oddities", "[BoostAODE]")
|
|||||||
{ { "select_features","IWSS" }, { "threshold", 0.51 } },
|
{ { "select_features","IWSS" }, { "threshold", 0.51 } },
|
||||||
{ { "select_features","FCBF" }, { "threshold", 1e-8 } },
|
{ { "select_features","FCBF" }, { "threshold", 1e-8 } },
|
||||||
{ { "select_features","FCBF" }, { "threshold", 1.01 } },
|
{ { "select_features","FCBF" }, { "threshold", 1.01 } },
|
||||||
{ { "alpha_block", true }, { "block_update", true } },
|
|
||||||
{ { "bisection", false }, { "block_update", true } },
|
|
||||||
};
|
};
|
||||||
for (const auto& hyper : bad_hyper_fit.items()) {
|
for (const auto& hyper : bad_hyper_fit.items()) {
|
||||||
INFO("BoostAODE hyper: " << hyper.value().dump());
|
INFO("BoostAODE hyper: " << hyper.value().dump());
|
||||||
clf.setHyperparameters(hyper.value());
|
clf.setHyperparameters(hyper.value());
|
||||||
REQUIRE_THROWS_AS(clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing), std::invalid_argument);
|
REQUIRE_THROWS_AS(clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing), std::invalid_argument);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto bad_hyper_fit2 = nlohmann::json{
|
||||||
|
{ { "alpha_block", true }, { "block_update", true } },
|
||||||
|
{ { "bisection", false }, { "block_update", true } },
|
||||||
|
};
|
||||||
|
for (const auto& hyper : bad_hyper_fit2.items()) {
|
||||||
|
INFO("BoostAODE hyper: " << hyper.value().dump());
|
||||||
|
REQUIRE_THROWS_AS(clf.setHyperparameters(hyper.value()), std::invalid_argument);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
TEST_CASE("Bisection Best", "[BoostAODE]")
|
TEST_CASE("Bisection Best", "[BoostAODE]")
|
||||||
{
|
{
|
||||||
|
@@ -17,7 +17,7 @@
|
|||||||
|
|
||||||
std::map<std::string, std::string> modules = {
|
std::map<std::string, std::string> modules = {
|
||||||
{ "mdlp", "2.0.1" },
|
{ "mdlp", "2.0.1" },
|
||||||
{ "Folding", "1.1.0" },
|
{ "Folding", "1.1.1" },
|
||||||
{ "json", "3.11" },
|
{ "json", "3.11" },
|
||||||
{ "ArffFiles", "1.1.0" }
|
{ "ArffFiles", "1.1.0" }
|
||||||
};
|
};
|
||||||
|
Submodule tests/lib/catch2 updated: 506276c592...0321d2fce3
Reference in New Issue
Block a user