Implement SPnDE and A2DE
This commit is contained in:
parent
8115f25c06
commit
f806015b29
37
bayesnet/classifiers/SPnDE.cc
Normal file
37
bayesnet/classifiers/SPnDE.cc
Normal file
@ -0,0 +1,37 @@
|
||||
// ***************************************************************
|
||||
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
|
||||
// SPDX-FileType: SOURCE
|
||||
// SPDX-License-Identifier: MIT
|
||||
// ***************************************************************
|
||||
|
||||
#include "SPnDE.h"
|
||||
|
||||
namespace bayesnet {
|
||||
|
||||
SPnDE::SPnDE(std::vector<int> parents) : Classifier(Network()), parents(parents) {}
|
||||
|
||||
void SPnDE::buildModel(const torch::Tensor& weights)
|
||||
{
|
||||
// 0. Add all nodes to the model
|
||||
addNodes();
|
||||
std::vector<int> attributes;
|
||||
for (int i = 0; i < static_cast<int>(features.size()); ++i) {
|
||||
if (std::find(parents.begin(), parents.end(), i) != parents.end()) {
|
||||
attributes.push_back(i);
|
||||
}
|
||||
}
|
||||
// 1. Add edges from the class node to all other nodes
|
||||
// 2. Add edges from the parents nodes to all other nodes
|
||||
for (const auto& attribute : attributes) {
|
||||
model.addEdge(className, features[attribute]);
|
||||
for (const auto& root : parents) {
|
||||
model.addEdge(features[root], features[attribute]);
|
||||
}
|
||||
}
|
||||
}
|
||||
std::vector<std::string> SPnDE::graph(const std::string& name) const
|
||||
{
|
||||
return model.graph(name);
|
||||
}
|
||||
|
||||
}
|
26
bayesnet/classifiers/SPnDE.h
Normal file
26
bayesnet/classifiers/SPnDE.h
Normal file
@ -0,0 +1,26 @@
|
||||
// ***************************************************************
|
||||
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
|
||||
// SPDX-FileType: SOURCE
|
||||
// SPDX-License-Identifier: MIT
|
||||
// ***************************************************************
|
||||
|
||||
#ifndef SPnDE_H
|
||||
#define SPnDE_H
|
||||
#include <vector>
|
||||
#include "Classifier.h"
|
||||
|
||||
namespace bayesnet {
|
||||
class SPnDE : public Classifier {
|
||||
public:
|
||||
explicit SPnDE(std::vector<int> parents);
|
||||
virtual ~SPnDE() = default;
|
||||
std::vector<std::string> graph(const std::string& name = "SPnDE") const override;
|
||||
protected:
|
||||
void buildModel(const torch::Tensor& weights) override;
|
||||
private:
|
||||
std::vector<int> parents;
|
||||
|
||||
|
||||
};
|
||||
}
|
||||
#endif
|
40
bayesnet/ensembles/A2DE.cc
Normal file
40
bayesnet/ensembles/A2DE.cc
Normal file
@ -0,0 +1,40 @@
|
||||
// ***************************************************************
|
||||
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
|
||||
// SPDX-FileType: SOURCE
|
||||
// SPDX-License-Identifier: MIT
|
||||
// ***************************************************************
|
||||
|
||||
#include "A2DE.h"
|
||||
|
||||
namespace bayesnet {
|
||||
A2DE::A2DE(bool predict_voting) : Ensemble(predict_voting)
|
||||
{
|
||||
validHyperparameters = { "predict_voting" };
|
||||
|
||||
}
|
||||
void A2DE::setHyperparameters(const nlohmann::json& hyperparameters_)
|
||||
{
|
||||
auto hyperparameters = hyperparameters_;
|
||||
if (hyperparameters.contains("predict_voting")) {
|
||||
predict_voting = hyperparameters["predict_voting"];
|
||||
hyperparameters.erase("predict_voting");
|
||||
}
|
||||
Classifier::setHyperparameters(hyperparameters);
|
||||
}
|
||||
void A2DE::buildModel(const torch::Tensor& weights)
|
||||
{
|
||||
models.clear();
|
||||
significanceModels.clear();
|
||||
for (int i = 0; i < features.size() - 1; ++i) {
|
||||
for (int j = i + 1; j < features.size(); ++j) {
|
||||
models.push_back(std::make_unique<SPnDE>(std::vector<int>({ i, j })));
|
||||
}
|
||||
}
|
||||
n_models = models.size();
|
||||
significanceModels = std::vector<double>(n_models, 1.0);
|
||||
}
|
||||
std::vector<std::string> A2DE::graph(const std::string& title) const
|
||||
{
|
||||
return Ensemble::graph(title);
|
||||
}
|
||||
}
|
22
bayesnet/ensembles/A2DE.h
Normal file
22
bayesnet/ensembles/A2DE.h
Normal file
@ -0,0 +1,22 @@
|
||||
// ***************************************************************
|
||||
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
|
||||
// SPDX-FileType: SOURCE
|
||||
// SPDX-License-Identifier: MIT
|
||||
// ***************************************************************
|
||||
|
||||
#ifndef A2DE_H
|
||||
#define A2DE_H
|
||||
#include "bayesnet/classifiers/SPnDE.h"
|
||||
#include "Ensemble.h"
|
||||
namespace bayesnet {
|
||||
class A2DE : public Ensemble {
|
||||
public:
|
||||
A2DE(bool predict_voting = false);
|
||||
virtual ~A2DE() {};
|
||||
void setHyperparameters(const nlohmann::json& hyperparameters) override;
|
||||
std::vector<std::string> graph(const std::string& title = "A2DE") const override;
|
||||
protected:
|
||||
void buildModel(const torch::Tensor& weights) override;
|
||||
};
|
||||
}
|
||||
#endif
|
1
lib/catch2
Submodule
1
lib/catch2
Submodule
@ -0,0 +1 @@
|
||||
Subproject commit 029fe3b4609dd84cd939b73357f37bbb75bcf82f
|
@ -21,4 +21,5 @@ if(ENABLE_TESTING)
|
||||
add_test(NAME Ensemble COMMAND TestBayesNet "[Ensemble]")
|
||||
add_test(NAME Models COMMAND TestBayesNet "[Models]")
|
||||
add_test(NAME BoostAODE COMMAND TestBayesNet "[BoostAODE]")
|
||||
add_test(NAME A2DE COMMAND TestBayesNet "[A2DE]")
|
||||
endif(ENABLE_TESTING)
|
||||
|
26
tests/TestA2DE.cc
Normal file
26
tests/TestA2DE.cc
Normal file
@ -0,0 +1,26 @@
|
||||
// ***************************************************************
|
||||
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
|
||||
// SPDX-FileType: SOURCE
|
||||
// SPDX-License-Identifier: MIT
|
||||
// ***************************************************************
|
||||
|
||||
#include <type_traits>
|
||||
#include <catch2/catch_test_macros.hpp>
|
||||
#include <catch2/catch_approx.hpp>
|
||||
#include <catch2/generators/catch_generators.hpp>
|
||||
#include "bayesnet/ensembles/A2DE.h"
|
||||
#include "TestUtils.h"
|
||||
|
||||
|
||||
TEST_CASE("Fit and Score", "[A2DE]")
|
||||
{
|
||||
auto raw = RawDatasets("glass", true);
|
||||
auto clf = bayesnet::A2DE();
|
||||
clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states);
|
||||
std::cout << "Score A2DE: " << clf.score(raw.Xv, raw.yv) << std::endl;
|
||||
// REQUIRE(clf.getNumberOfNodes() == 90);
|
||||
// REQUIRE(clf.getNumberOfEdges() == 153);
|
||||
// REQUIRE(clf.getNotes().size() == 2);
|
||||
// REQUIRE(clf.getNotes()[0] == "Used features in initialization: 6 of 9 with CFS");
|
||||
// REQUIRE(clf.getNotes()[1] == "Number of models: 9");
|
||||
}
|
Loading…
Reference in New Issue
Block a user