diff --git a/bayesnet/classifiers/SPnDE.cc b/bayesnet/classifiers/SPnDE.cc new file mode 100644 index 0000000..4785b26 --- /dev/null +++ b/bayesnet/classifiers/SPnDE.cc @@ -0,0 +1,37 @@ +// *************************************************************** +// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez +// SPDX-FileType: SOURCE +// SPDX-License-Identifier: MIT +// *************************************************************** + +#include "SPnDE.h" + +namespace bayesnet { + + SPnDE::SPnDE(std::vector parents) : Classifier(Network()), parents(parents) {} + + void SPnDE::buildModel(const torch::Tensor& weights) + { + // 0. Add all nodes to the model + addNodes(); + std::vector attributes; + for (int i = 0; i < static_cast(features.size()); ++i) { + if (std::find(parents.begin(), parents.end(), i) != parents.end()) { + attributes.push_back(i); + } + } + // 1. Add edges from the class node to all other nodes + // 2. Add edges from the parents nodes to all other nodes + for (const auto& attribute : attributes) { + model.addEdge(className, features[attribute]); + for (const auto& root : parents) { + model.addEdge(features[root], features[attribute]); + } + } + } + std::vector SPnDE::graph(const std::string& name) const + { + return model.graph(name); + } + +} \ No newline at end of file diff --git a/bayesnet/classifiers/SPnDE.h b/bayesnet/classifiers/SPnDE.h new file mode 100644 index 0000000..9ab79ad --- /dev/null +++ b/bayesnet/classifiers/SPnDE.h @@ -0,0 +1,26 @@ +// *************************************************************** +// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez +// SPDX-FileType: SOURCE +// SPDX-License-Identifier: MIT +// *************************************************************** + +#ifndef SPnDE_H +#define SPnDE_H +#include +#include "Classifier.h" + +namespace bayesnet { + class SPnDE : public Classifier { + public: + explicit SPnDE(std::vector parents); + virtual ~SPnDE() = default; + std::vector graph(const std::string& name = "SPnDE") const override; + protected: + void buildModel(const torch::Tensor& weights) override; + private: + std::vector parents; + + + }; +} +#endif \ No newline at end of file diff --git a/bayesnet/ensembles/A2DE.cc b/bayesnet/ensembles/A2DE.cc new file mode 100644 index 0000000..27e709c --- /dev/null +++ b/bayesnet/ensembles/A2DE.cc @@ -0,0 +1,40 @@ +// *************************************************************** +// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez +// SPDX-FileType: SOURCE +// SPDX-License-Identifier: MIT +// *************************************************************** + +#include "A2DE.h" + +namespace bayesnet { + A2DE::A2DE(bool predict_voting) : Ensemble(predict_voting) + { + validHyperparameters = { "predict_voting" }; + + } + void A2DE::setHyperparameters(const nlohmann::json& hyperparameters_) + { + auto hyperparameters = hyperparameters_; + if (hyperparameters.contains("predict_voting")) { + predict_voting = hyperparameters["predict_voting"]; + hyperparameters.erase("predict_voting"); + } + Classifier::setHyperparameters(hyperparameters); + } + void A2DE::buildModel(const torch::Tensor& weights) + { + models.clear(); + significanceModels.clear(); + for (int i = 0; i < features.size() - 1; ++i) { + for (int j = i + 1; j < features.size(); ++j) { + models.push_back(std::make_unique(std::vector({ i, j }))); + } + } + n_models = models.size(); + significanceModels = std::vector(n_models, 1.0); + } + std::vector A2DE::graph(const std::string& title) const + { + return Ensemble::graph(title); + } +} \ No newline at end of file diff --git a/bayesnet/ensembles/A2DE.h b/bayesnet/ensembles/A2DE.h new file mode 100644 index 0000000..2307dbd --- /dev/null +++ b/bayesnet/ensembles/A2DE.h @@ -0,0 +1,22 @@ +// *************************************************************** +// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez +// SPDX-FileType: SOURCE +// SPDX-License-Identifier: MIT +// *************************************************************** + +#ifndef A2DE_H +#define A2DE_H +#include "bayesnet/classifiers/SPnDE.h" +#include "Ensemble.h" +namespace bayesnet { + class A2DE : public Ensemble { + public: + A2DE(bool predict_voting = false); + virtual ~A2DE() {}; + void setHyperparameters(const nlohmann::json& hyperparameters) override; + std::vector graph(const std::string& title = "A2DE") const override; + protected: + void buildModel(const torch::Tensor& weights) override; + }; +} +#endif \ No newline at end of file diff --git a/lib/catch2 b/lib/catch2 new file mode 160000 index 0000000..029fe3b --- /dev/null +++ b/lib/catch2 @@ -0,0 +1 @@ +Subproject commit 029fe3b4609dd84cd939b73357f37bbb75bcf82f diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 1e116c0..b51bff3 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -21,4 +21,5 @@ if(ENABLE_TESTING) add_test(NAME Ensemble COMMAND TestBayesNet "[Ensemble]") add_test(NAME Models COMMAND TestBayesNet "[Models]") add_test(NAME BoostAODE COMMAND TestBayesNet "[BoostAODE]") + add_test(NAME A2DE COMMAND TestBayesNet "[A2DE]") endif(ENABLE_TESTING) diff --git a/tests/TestA2DE.cc b/tests/TestA2DE.cc new file mode 100644 index 0000000..8cd128e --- /dev/null +++ b/tests/TestA2DE.cc @@ -0,0 +1,26 @@ +// *************************************************************** +// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez +// SPDX-FileType: SOURCE +// SPDX-License-Identifier: MIT +// *************************************************************** + +#include +#include +#include +#include +#include "bayesnet/ensembles/A2DE.h" +#include "TestUtils.h" + + +TEST_CASE("Fit and Score", "[A2DE]") +{ + auto raw = RawDatasets("glass", true); + auto clf = bayesnet::A2DE(); + clf.fit(raw.Xv, raw.yv, raw.features, raw.className, raw.states); + std::cout << "Score A2DE: " << clf.score(raw.Xv, raw.yv) << std::endl; + // REQUIRE(clf.getNumberOfNodes() == 90); + // REQUIRE(clf.getNumberOfEdges() == 153); + // REQUIRE(clf.getNotes().size() == 2); + // REQUIRE(clf.getNotes()[0] == "Used features in initialization: 6 of 9 with CFS"); + // REQUIRE(clf.getNotes()[1] == "Number of models: 9"); +}