Compare commits: 3d6b4f0614...alphablock

5 Commits:

- b571a4da4d
- 8a9f329ff9
- e2781ee525
- 56a2d3ead0
- dc32a0fc47
.gitmodules (vendored, 2 lines changed)
@@ -18,4 +18,4 @@
 	url = https://github.com/rmontanana/ArffFiles
 [submodule "lib/mdlp"]
 	path = lib/mdlp
 	url = https://github.com/rmontanana/mdlp
CHANGELOG.md

@@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
+- Add a new hyperparameter to the BoostAODE class, *alpha_block*, to control the way α is computed, either with the last model or with the ensemble built so far. Default value is *false*.
+- Add a new hyperparameter to the SPODE class, *parent*, to set the root node of the model. If no value is set, the root parameter of the constructor is used.
+- Add a new hyperparameter to the TAN class, *parent*, to set the root node of the model. If not set, the first feature is used as root.
 
 ## [1.0.6] 2024-11-23
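The new options are passed through each classifier's `setHyperparameters(const nlohmann::json&)` method, as the code and test changes below show. A minimal usage sketch follows; the TAN/SPODE include paths and the default constructors are assumptions modelled on the `bayesnet/ensembles/BoostAODE.h` include that appears in the tests, and only the `setHyperparameters` calls mirror the diffs directly.

```cpp
// Hedged sketch, not part of the diff: setting the hyperparameters introduced
// above. Include paths for TAN/SPODE and the BoostAODE default constructor are
// assumptions; the setHyperparameters calls mirror the tests further down.
#include "bayesnet/classifiers/TAN.h"
#include "bayesnet/classifiers/SPODE.h"
#include "bayesnet/ensembles/BoostAODE.h"

int main()
{
    bayesnet::TAN tan;
    tan.setHyperparameters({ { "parent", 1 } });            // feature 1 becomes the root node
    bayesnet::SPODE spode(0);                               // root given to the constructor...
    spode.setHyperparameters({ { "parent", 1 } });          // ...overridden by the hyperparameter
    bayesnet::BoostAODE boost;
    boost.setHyperparameters({ { "alpha_block", true } });  // compute α with the ensemble built so far
    // fit(X, y, features, className, states, smoothing) would follow, as in the tests below
    return 0;
}
```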
README.md

@@ -18,7 +18,7 @@ The only external dependency is [libtorch](https://pytorch.org/cppdocs/installin
 
 ```bash
 wget https://download.pytorch.org/libtorch/nightly/cpu/libtorch-shared-with-deps-latest.zip
-unzip libtorch-shared-with-deps-latest.zips
+unzip libtorch-shared-with-deps-latest.zip
 ```
 
 ## Setup
SPODE.cc

@@ -8,14 +8,29 @@
 
 namespace bayesnet {
 
-    SPODE::SPODE(int root) : Classifier(Network()), root(root) {}
+    SPODE::SPODE(int root) : Classifier(Network()), root(root)
+    {
+        validHyperparameters = { "parent" };
+    }
 
+    void SPODE::setHyperparameters(const nlohmann::json& hyperparameters_)
+    {
+        auto hyperparameters = hyperparameters_;
+        if (hyperparameters.contains("parent")) {
+            root = hyperparameters["parent"];
+            hyperparameters.erase("parent");
+        }
+        Classifier::setHyperparameters(hyperparameters);
+    }
     void SPODE::buildModel(const torch::Tensor& weights)
     {
         // 0. Add all nodes to the model
         addNodes();
         // 1. Add edges from the class node to all other nodes
         // 2. Add edges from the root node to all other nodes
+        if (root >= static_cast<int>(features.size())) {
+            throw std::invalid_argument("The parent node is not in the dataset");
+        }
         for (int i = 0; i < static_cast<int>(features.size()); ++i) {
             model.addEdge(className, features[i]);
             if (i != root) {
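SPODE::setHyperparameters above follows a consume-and-delegate pattern: copy the incoming JSON, consume the keys this class understands, and forward whatever is left to the base class for validation. A stripped-down, self-contained illustration of just that pattern; Base and Derived are hypothetical stand-ins, not library classes, and the validation message is illustrative only.

```cpp
#include <nlohmann/json.hpp>
#include <iostream>
#include <stdexcept>

// Hypothetical stand-ins (Base ~ Classifier, Derived ~ SPODE) used only to
// show the consume-and-delegate pattern; they are not library classes.
struct Base {
    virtual ~Base() = default;
    virtual void setHyperparameters(const nlohmann::json& hp)
    {
        // Stand-in for the base-class validation step: anything the derived
        // class did not consume is treated as an unknown hyperparameter.
        if (!hp.empty()) {
            throw std::invalid_argument("Hyperparameter " + hp.begin().key() + " is not valid");
        }
    }
};

struct Derived : Base {
    int root = 0;
    void setHyperparameters(const nlohmann::json& hp_) override
    {
        auto hp = hp_;                // work on a copy
        if (hp.contains("parent")) {
            root = hp["parent"];      // consume the key this class understands
            hp.erase("parent");
        }
        Base::setHyperparameters(hp); // delegate whatever is left
    }
};

int main()
{
    Derived d;
    d.setHyperparameters({ { "parent", 2 } });         // accepted: root becomes 2
    try {
        d.setHyperparameters({ { "unknown", true } }); // rejected by the base step
    } catch (const std::invalid_argument& e) {
        std::cout << e.what() << std::endl;
    }
    return 0;
}
```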
SPODE.h

@@ -10,14 +10,15 @@
 
 namespace bayesnet {
     class SPODE : public Classifier {
-    private:
-        int root;
-    protected:
-        void buildModel(const torch::Tensor& weights) override;
     public:
         explicit SPODE(int root);
         virtual ~SPODE() = default;
+        void setHyperparameters(const nlohmann::json& hyperparameters_) override;
         std::vector<std::string> graph(const std::string& name = "SPODE") const override;
+    protected:
+        void buildModel(const torch::Tensor& weights) override;
+    private:
+        int root;
     };
 }
 #endif
TAN.cc

@@ -7,8 +7,20 @@
 #include "TAN.h"
 
 namespace bayesnet {
-    TAN::TAN() : Classifier(Network()) {}
+    TAN::TAN() : Classifier(Network())
+    {
+        validHyperparameters = { "parent" };
+    }
 
+    void TAN::setHyperparameters(const nlohmann::json& hyperparameters_)
+    {
+        auto hyperparameters = hyperparameters_;
+        if (hyperparameters.contains("parent")) {
+            parent = hyperparameters["parent"];
+            hyperparameters.erase("parent");
+        }
+        Classifier::setHyperparameters(hyperparameters);
+    }
     void TAN::buildModel(const torch::Tensor& weights)
     {
         // 0. Add all nodes to the model

@@ -23,7 +35,10 @@ namespace bayesnet {
             mi.push_back({ i, mi_value });
         }
         sort(mi.begin(), mi.end(), [](const auto& left, const auto& right) {return left.second < right.second;});
-        auto root = mi[mi.size() - 1].first;
+        auto root = parent == -1 ? mi[mi.size() - 1].first : parent;
+        if (root >= static_cast<int>(features.size())) {
+            throw std::invalid_argument("The parent node is not in the dataset");
+        }
         // 2. Compute mutual information between each feature and the class
         auto weights_matrix = metrics.conditionalEdge(weights);
         // 3. Compute the maximum spanning tree
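The behavioural change in TAN::buildModel is the root selection: the feature with the highest mutual information with the class is used unless an explicit parent was configured. A self-contained sketch of just that rule; selectRoot is a hypothetical helper extracted for illustration, not a library function.

```cpp
#include <algorithm>
#include <stdexcept>
#include <utility>
#include <vector>

// Sketch of the rule above. mi holds (feature index, mutual information with
// the class); parent == -1 means "no parent hyperparameter was set".
int selectRoot(std::vector<std::pair<int, double>> mi, int parent, int numFeatures)
{
    std::sort(mi.begin(), mi.end(),
              [](const auto& left, const auto& right) { return left.second < right.second; });
    int root = parent == -1 ? mi.back().first : parent;  // highest MI unless overridden
    if (root >= numFeatures) {
        throw std::invalid_argument("The parent node is not in the dataset");
    }
    return root;
}
```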
TAN.h

@@ -9,13 +9,15 @@
 #include "Classifier.h"
 namespace bayesnet {
     class TAN : public Classifier {
-    private:
-    protected:
-        void buildModel(const torch::Tensor& weights) override;
     public:
         TAN();
         virtual ~TAN() = default;
+        void setHyperparameters(const nlohmann::json& hyperparameters_) override;
         std::vector<std::string> graph(const std::string& name = "TAN") const override;
+    protected:
+        void buildModel(const torch::Tensor& weights) override;
+    private:
+        int parent = -1;
     };
 }
 #endif
Submodule lib/catch2 deleted from 029fe3b460
Submodule lib/folding updated: 2ac43e32ac...9652853d69
lib/json (2 lines changed)
Submodule lib/json updated: 378e091795...620034ecec
@@ -267,4 +267,36 @@ TEST_CASE("Predict, predict_proba & score without fitting", "[Models]")
     REQUIRE_THROWS_WITH(clf.predict_proba(raw.Xt), message);
     REQUIRE_THROWS_WITH(clf.score(raw.Xv, raw.yv), message);
     REQUIRE_THROWS_WITH(clf.score(raw.Xt, raw.yt), message);
 }
+TEST_CASE("TAN & SPODE with hyperparameters", "[Models]")
+{
+    auto raw = RawDatasets("iris", true);
+    auto clf = bayesnet::TAN();
+    clf.setHyperparameters({
+        {"parent", 1},
+        });
+    clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing);
+    auto score = clf.score(raw.Xv, raw.yv);
+    REQUIRE(score == Catch::Approx(0.973333).epsilon(raw.epsilon));
+    auto clf2 = bayesnet::SPODE(0);
+    clf2.setHyperparameters({
+        {"parent", 1},
+        });
+    clf2.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing);
+    auto score2 = clf2.score(raw.Xv, raw.yv);
+    REQUIRE(score2 == Catch::Approx(0.973333).epsilon(raw.epsilon));
+}
+TEST_CASE("TAN & SPODE with invalid hyperparameters", "[Models]")
+{
+    auto raw = RawDatasets("iris", true);
+    auto clf = bayesnet::TAN();
+    clf.setHyperparameters({
+        {"parent", 5},
+        });
+    REQUIRE_THROWS_AS(clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing), std::invalid_argument);
+    auto clf2 = bayesnet::SPODE(0);
+    clf2.setHyperparameters({
+        {"parent", 5},
+        });
+    REQUIRE_THROWS_AS(clf2.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing), std::invalid_argument);
+}
@@ -7,7 +7,7 @@
 #include <type_traits>
 #include <catch2/catch_test_macros.hpp>
 #include <catch2/catch_approx.hpp>
 #include <catch2/generators/catch_generators.hpp>
 #include <catch2/matchers/catch_matchers.hpp>
 #include "bayesnet/ensembles/BoostAODE.h"
 #include "TestUtils.h"

@@ -130,14 +130,21 @@ TEST_CASE("Oddities", "[BoostAODE]")
         { { "select_features","IWSS" }, { "threshold", 0.51 } },
         { { "select_features","FCBF" }, { "threshold", 1e-8 } },
         { { "select_features","FCBF" }, { "threshold", 1.01 } },
-        { { "alpha_block", true }, { "block_update", true } },
-        { { "bisection", false }, { "block_update", true } },
     };
     for (const auto& hyper : bad_hyper_fit.items()) {
         INFO("BoostAODE hyper: " << hyper.value().dump());
         clf.setHyperparameters(hyper.value());
         REQUIRE_THROWS_AS(clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states, raw.smoothing), std::invalid_argument);
     }
 
+    auto bad_hyper_fit2 = nlohmann::json{
+        { { "alpha_block", true }, { "block_update", true } },
+        { { "bisection", false }, { "block_update", true } },
+    };
+    for (const auto& hyper : bad_hyper_fit2.items()) {
+        INFO("BoostAODE hyper: " << hyper.value().dump());
+        REQUIRE_THROWS_AS(clf.setHyperparameters(hyper.value()), std::invalid_argument);
+    }
 }
 TEST_CASE("Bisection Best", "[BoostAODE]")
 {
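The second loop above expects the incompatible combinations to be rejected as soon as they are set, rather than later during fit. A hypothetical, self-contained sketch of that kind of up-front compatibility check; the function name, defaults, and messages are illustrative, not the library's.

```cpp
#include <nlohmann/json.hpp>
#include <stdexcept>

// Hypothetical sketch of an up-front compatibility check for the combinations
// exercised in the test; names, defaults, and messages are assumptions.
void checkHyperparameterCompatibility(const nlohmann::json& hp)
{
    bool alpha_block = hp.value("alpha_block", false);
    bool block_update = hp.value("block_update", false);
    bool bisection = hp.value("bisection", true);
    if (alpha_block && block_update) {
        throw std::invalid_argument("alpha_block and block_update cannot both be enabled");
    }
    if (block_update && !bisection) {
        throw std::invalid_argument("block_update requires bisection");
    }
}
```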
@@ -17,7 +17,7 @@
 
 std::map<std::string, std::string> modules = {
     { "mdlp", "2.0.1" },
-    { "Folding", "1.1.0" },
+    { "Folding", "1.1.1" },
     { "json", "3.11" },
     { "ArffFiles", "1.1.0" }
 };
Submodule tests/lib/catch2 updated: 506276c592...0321d2fce3