Review note (text before the first file header is ignored by git apply / patch).
This is the TAN / maximum-spanning-tree patch restored to one diff line per
source line, with the following fixes applied to the added lines only:
  * MST::maximumSpanningTree: the feature count is now a signed int (with
    size_t, "size() - 1" wraps around when there are no features) and the inner
    loop starts at j = i + 1 so no self-loop edges are generated.
  * MST constructor: initialiser list follows the member declaration order.
  * Graph: V is initialised by the constructor, which is now explicit;
    MST::root has a default value for the defaulted constructor.
  * Metrics::maximumSpanningTree: the unused `result` vector and the
    `return result;` after the new return statement are removed.
  * Mst.cc: includes the standard headers it uses (sort, cout, unordered_set).
  * Loops use range-for / const references instead of signed-vs-size() indices
    and per-element copies.
Assumptions to confirm against the tree before applying:
  * Template arguments (vector<pair<int, int>>, vector<string>, item<float>, ...)
    were missing from the reviewed text and are reconstructed from how the
    values are used.
  * Context lines whose template arguments could not be inferred were trimmed
    from their hunks instead of being guessed.
  * Indentation, and the line split of the comment / blank lines in the second
    src/Metrics.cc hunk, are a best guess.
  * The index lines are carried over unchanged and no longer describe the new
    blobs.

diff --git a/sample/main.cc b/sample/main.cc
index d496a29..ee453ca 100644
--- a/sample/main.cc
+++ b/sample/main.cc
@@ -10,6 +10,7 @@
 #include "KDB.h"
 #include "SPODE.h"
 #include "AODE.h"
+#include "TAN.h"
 
 using namespace std;
 
@@ -282,5 +283,13 @@ int main(int argc, char** argv)
     }
     cout << "Score: " << aode.score(Xd, y) << endl;
     cout << "****************** AODE ******************" << endl;
+    cout << "****************** TAN ******************" << endl;
+    auto tan = bayesnet::TAN();
+    tan.fit(Xd, y, features, className, states);
+    for (const auto& line : tan.show()) {
+        cout << line << endl;
+    }
+    cout << "Score: " << tan.score(Xd, y) << endl;
+    cout << "****************** TAN ******************" << endl;
     return 0;
 }
\ No newline at end of file
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 8bb813a..7845477 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -1,2 +1,2 @@
-add_library(BayesNet utils.cc Network.cc Node.cc Metrics.cc BaseClassifier.cc KDB.cc TAN.cc SPODE.cc Ensemble.cc AODE.cc)
+add_library(BayesNet utils.cc Network.cc Node.cc Metrics.cc BaseClassifier.cc KDB.cc TAN.cc SPODE.cc Ensemble.cc AODE.cc Mst.cc)
 target_link_libraries(BayesNet "${TORCH_LIBRARIES}")
\ No newline at end of file
diff --git a/src/Metrics.cc b/src/Metrics.cc
index a150739..887706a 100644
--- a/src/Metrics.cc
+++ b/src/Metrics.cc
@@ -1,3 +1,4 @@
 #include "Metrics.hpp"
+#include "Mst.h"
 using namespace std;
 namespace bayesnet {
@@ -121,14 +122,9 @@ namespace bayesnet {
     and the indices of the weights as nodes of this square matrix using
     Kruskal algorithm
     */
-    vector<pair<int, int>> Metrics::maximumSpanningTree(int root, Tensor& weights)
+    vector<pair<int, int>> Metrics::maximumSpanningTree(vector<string> features, Tensor& weights, int root)
     {
-        auto result = vector<pair<int, int>>();
+        auto mst = MST(features, weights, root);
+        return mst.maximumSpanningTree();
-
-
-
-
-
-        return result;
     }
 }
\ No newline at end of file
diff --git a/src/Metrics.hpp b/src/Metrics.hpp
index f320fed..f8557a0 100644
--- a/src/Metrics.hpp
+++ b/src/Metrics.hpp
@@ -25,4 +25,4 @@ namespace bayesnet {
-        vector<pair<int, int>> maximumSpanningTree(int root, Tensor& weights);
+        vector<pair<int, int>> maximumSpanningTree(vector<string> features, Tensor& weights, int root);
     };
 }
 #endif
\ No newline at end of file
diff --git a/src/Mst.cc b/src/Mst.cc
new file mode 100644
index 0000000..ce75376
--- /dev/null
+++ b/src/Mst.cc
@@ -0,0 +1,114 @@
+#include "Mst.h"
+#include <algorithm>
+#include <iostream>
+#include <unordered_set>
+#include <vector>
+/*
+    Based on the code from https://www.softwaretestinghelp.com/minimum-spanning-tree-tutorial/
+
+*/
+
+namespace bayesnet {
+    using namespace std;
+    // Every node starts as the representative of its own one-element set.
+    Graph::Graph(int V) : V(V), parent(V)
+    {
+        for (int i = 0; i < V; i++)
+            parent[i] = i;
+    }
+    void Graph::addEdge(int u, int v, float wt)
+    {
+        G.push_back({ wt, { u, v } });
+    }
+    // Follows parent links until reaching the node that is its own parent.
+    int Graph::find_set(int i)
+    {
+        while (i != parent[i])
+            i = parent[i];
+        return i;
+    }
+    // Merges two sets; kruskal_algorithm() passes the representatives returned by find_set().
+    void Graph::union_set(int u, int v)
+    {
+        parent[u] = parent[v];
+    }
+    void Graph::kruskal_algorithm()
+    {
+        // sort the edges ordered on decreasing weight
+        sort(G.begin(), G.end(), [](const auto& left, const auto& right) {return left.first > right.first;});
+        for (const auto& edge : G) {
+            const int uSt = find_set(edge.second.first);
+            const int vEd = find_set(edge.second.second);
+            if (uSt != vEd) {
+                T.push_back(edge); // add to mst vector
+                union_set(uSt, vEd);
+            }
+        }
+    }
+    void Graph::display_mst()
+    {
+        cout << "Edge :" << " Weight" << endl;
+        for (const auto& edge : T) {
+            cout << edge.second.first << " - " << edge.second.second << " : "
+                << edge.first;
+            cout << endl;
+        }
+    }
+
+    // Orients the undirected tree edges in T away from root_original; edges not
+    // reachable from the root are appended unchanged at the end.
+    vector<pair<int, int>> reorder(vector<pair<float, pair<int, int>>> T, int root_original)
+    {
+        auto result = vector<pair<int, int>>();
+        auto visited = vector<int>();
+        auto nextVariables = unordered_set<int>();
+        nextVariables.emplace(root_original);
+        while (nextVariables.size() > 0) {
+            int root = *nextVariables.begin();
+            nextVariables.erase(nextVariables.begin());
+            for (int i = 0; i < static_cast<int>(T.size()); ++i) {
+                auto [weight, edge] = T[i];
+                auto [from, to] = edge;
+                if (from == root || to == root) {
+                    visited.insert(visited.begin(), i);
+                    if (from == root) {
+                        result.push_back({ from, to });
+                        nextVariables.emplace(to);
+                    } else {
+                        result.push_back({ to, from });
+                        nextVariables.emplace(from);
+                    }
+                }
+            }
+            // Remove visited; indices are stored in descending order so erasing is safe
+            for (int index : visited) {
+                T.erase(T.begin() + index);
+            }
+            visited.clear();
+        }
+        for (const auto& [weight, edge] : T) {
+            result.push_back(edge);
+        }
+        return result;
+    }
+
+    MST::MST(vector<string>& features, Tensor& weights, int root) : weights(weights), features(features), root(root) {}
+    vector<pair<int, int>> MST::maximumSpanningTree()
+    {
+        // Signed count: with size_t, "size() - 1" wraps around when there are no features
+        const int num_features = static_cast<int>(features.size());
+        Graph g(num_features);
+
+        // Make a complete graph (j starts at i + 1: self-loops can never be tree edges)
+        for (int i = 0; i < num_features - 1; ++i) {
+            for (int j = i + 1; j < num_features; ++j) {
+                g.addEdge(i, j, weights[i][j].item<float>());
+            }
+        }
+        g.kruskal_algorithm();
+        //g.display_mst();
+        auto mst = g.get_mst();
+        return reorder(mst, root);
+    }
+
+}
\ No newline at end of file
diff --git a/src/Mst.h b/src/Mst.h
new file mode 100644
index 0000000..15b0dbb
--- /dev/null
+++ b/src/Mst.h
@@ -0,0 +1,37 @@
+#ifndef MST_H
+#define MST_H
+#include <torch/torch.h>
+#include <string>
+#include <vector>
+namespace bayesnet {
+    using namespace std;
+    using namespace torch;
+    // Maximum spanning tree over the feature indices, oriented away from `root`
+    class MST {
+    private:
+        Tensor weights;
+        vector<string> features;
+        int root = 0;
+    public:
+        MST() = default;
+        MST(vector<string>& features, Tensor& weights, int root);
+        vector<pair<int, int>> maximumSpanningTree();
+    };
+    // Weighted edge list plus the parent array used by kruskal_algorithm()
+    class Graph {
+    private:
+        int V; // number of nodes in graph
+        vector <pair<float, pair<int, int>>> G; // vector for graph
+        vector <pair<float, pair<int, int>>> T; // vector for mst
+        vector<int> parent;
+    public:
+        explicit Graph(int V);
+        void addEdge(int u, int v, float wt);
+        int find_set(int i);
+        void union_set(int u, int v);
+        void kruskal_algorithm();
+        void display_mst();
+        vector <pair<float, pair<int, int>>> get_mst() { return T; }
+    };
+}
+#endif
\ No newline at end of file
diff --git a/src/TAN.cc b/src/TAN.cc
index e33e715..bb0d561 100644
--- a/src/TAN.cc
+++ b/src/TAN.cc
@@ -19,16 +19,20 @@ namespace bayesnet {
             auto mi_value = metrics.mutualInformation(class_dataset, feature_dataset);
             mi.push_back({ i, mi_value });
         }
-        sort(mi.begin(), mi.end());
+        sort(mi.begin(), mi.end(), [](const auto& left, const auto& right) {return left.second < right.second;});
         auto root = mi[mi.size() - 1].first;
         // 2. Compute mutual information between each feature and the class
         auto weights = metrics.conditionalEdge();
         // 3. Compute the maximum spanning tree
-        auto mst = metrics.maximumSpanningTree(root, weights);
+        auto mst = metrics.maximumSpanningTree(features, weights, root);
         // 4. Add edges from the maximum spanning tree to the model
         for (auto i = 0; i < mst.size(); ++i) {
             auto [from, to] = mst[i];
             model.addEdge(features[from], features[to]);
         }
+        // 5. Add edges from the class to all features
+        for (const auto& feature : features) {
+            model.addEdge(className, feature);
+        }
     }
 }
\ No newline at end of file