From e8b8fa29c86dc55b3f0a6cffb4bb4d2051313e86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana?= Date: Fri, 14 Jul 2023 12:59:47 +0200 Subject: [PATCH] Add SPODE --- .vscode/settings.json | 8 +++++++- src/CMakeLists.txt | 2 +- src/Metrics.cc | 2 +- src/Metrics.hpp | 2 +- src/SPODE.cc | 20 ++++++++++++++++++++ src/SPODE.h | 14 ++++++++++++++ src/TAN.cc | 21 +++++++++++++++------ 7 files changed, 59 insertions(+), 10 deletions(-) create mode 100644 src/SPODE.cc create mode 100644 src/SPODE.h diff --git a/.vscode/settings.json b/.vscode/settings.json index 86e20ee..e0e4a2a 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -90,7 +90,13 @@ "format": "cpp", "valarray": "cpp", "regex": "cpp", - "span": "cpp" + "span": "cpp", + "cfenv": "cpp", + "cinttypes": "cpp", + "csetjmp": "cpp", + "future": "cpp", + "queue": "cpp", + "typeindex": "cpp" }, "cmake.configureOnOpen": false, "C_Cpp.default.configurationProvider": "ms-vscode.cmake-tools" diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b76c503..59728d5 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,2 +1,2 @@ -add_library(BayesNet Network.cc Node.cc Metrics.cc BaseClassifier.cc KDB.cc TAN.cc) +add_library(BayesNet Network.cc Node.cc Metrics.cc BaseClassifier.cc KDB.cc TAN.cc SPODE.cc) target_link_libraries(BayesNet "${TORCH_LIBRARIES}") \ No newline at end of file diff --git a/src/Metrics.cc b/src/Metrics.cc index 5c1343a..a150739 100644 --- a/src/Metrics.cc +++ b/src/Metrics.cc @@ -121,7 +121,7 @@ namespace bayesnet { and the indices of the weights as nodes of this square matrix using Kruskal algorithm */ - vector> Metrics::maximumSpanningTree(Tensor& weights) + vector> Metrics::maximumSpanningTree(int root, Tensor& weights) { auto result = vector>(); diff --git a/src/Metrics.hpp b/src/Metrics.hpp index 2934f83..f320fed 100644 --- a/src/Metrics.hpp +++ b/src/Metrics.hpp @@ -22,7 +22,7 @@ namespace bayesnet { vector conditionalEdgeWeights(); Tensor 
conditionalEdge(); vector> doCombinations(const vector&); - vector> maximumSpanningTree(Tensor& weights); + vector> maximumSpanningTree(int root, Tensor& weights); }; } #endif \ No newline at end of file diff --git a/src/SPODE.cc b/src/SPODE.cc new file mode 100644 index 0000000..4b0431b --- /dev/null +++ b/src/SPODE.cc @@ -0,0 +1,20 @@ +#include "SPODE.h" + +namespace bayesnet { + + SPODE::SPODE(int root) : BaseClassifier(Network()), root(root) {} + + void SPODE::train() + { + // 0. Add all nodes to the model + addNodes(); + // 1. Add edges from the class node to all other nodes + // 2. Add edges from the root node to all other nodes + for (int i = 0; i < static_cast(features.size()); ++i) { + model.addEdge(className, features[i]); + if (i != root) { + model.addEdge(features[root], features[i]); + } + } + } +} \ No newline at end of file diff --git a/src/SPODE.h b/src/SPODE.h new file mode 100644 index 0000000..f796d19 --- /dev/null +++ b/src/SPODE.h @@ -0,0 +1,14 @@ +#ifndef SPODE_H +#define SPODE_H +#include "BaseClassifier.h" +namespace bayesnet { + class SPODE : public BaseClassifier { + private: + int root; + protected: + void train() override; + public: + SPODE(int root); + }; +} +#endif \ No newline at end of file diff --git a/src/TAN.cc b/src/TAN.cc index d490a23..e33e715 100644 --- a/src/TAN.cc +++ b/src/TAN.cc @@ -10,16 +10,25 @@ namespace bayesnet { { // 0. Add all nodes to the model addNodes(); - // 1. Compute mutual information between each feature and the class + // 1. 
Compute mutual information between each feature and the class and set the root node + // as the highest mutual information with the class + auto mi = vector >(); + Tensor class_dataset = dataset.index({ "...", -1 }); + for (int i = 0; i < static_cast(features.size()); ++i) { + Tensor feature_dataset = dataset.index({ "...", i }); + auto mi_value = metrics.mutualInformation(class_dataset, feature_dataset); + mi.push_back({ i, mi_value }); + } + sort(mi.begin(), mi.end(), [](const auto& lhs, const auto& rhs) { return lhs.second < rhs.second; }); // order by MI value (.second); the default pair ordering would sort by feature index + auto root = mi[mi.size() - 1].first; + // 2. Compute the edge weights between features (metrics.conditionalEdge) auto weights = metrics.conditionalEdge(); - // 2. Compute the maximum spanning tree - auto mst = metrics.maximumSpanningTree(weights); - // 3. Add edges from the maximum spanning tree to the model + // 3. Compute the maximum spanning tree + auto mst = metrics.maximumSpanningTree(root, weights); + // 4. Add edges from the maximum spanning tree to the model for (auto i = 0; i < mst.size(); ++i) { auto [from, to] = mst[i]; model.addEdge(features[from], features[to]); } - } - } \ No newline at end of file