// *************************************************************** // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez // SPDX-FileType: SOURCE // SPDX-License-Identifier: MIT // *************************************************************** #include #include #include #include "Mst.h" /* Based on the code from https://www.softwaretestinghelp.com/minimum-spanning-tree-tutorial/ */ namespace bayesnet { Graph::Graph(int V) : V(V), parent(std::vector(V)) { for (int i = 0; i < V; i++) parent[i] = i; G.clear(); T.clear(); } void Graph::addEdge(int u, int v, float wt) { G.push_back({ wt, { u, v } }); } int Graph::find_set(int i) { // If i is the parent of itself if (i == parent[i]) return i; else //else recursively find the parent of i return find_set(parent[i]); } void Graph::union_set(int u, int v) { parent[u] = parent[v]; } void Graph::kruskal_algorithm() { // sort the edges ordered on decreasing weight stable_sort(G.begin(), G.end(), [](const auto& left, const auto& right) {return left.first > right.first;}); for (int i = 0; i < G.size(); i++) { int uSt, vEd; uSt = find_set(G[i].second.first); vEd = find_set(G[i].second.second); if (uSt != vEd) { T.push_back(G[i]); // add to mst std::vector union_set(uSt, vEd); } } } void insertElement(std::list& variables, int variable) { if (std::find(variables.begin(), variables.end(), variable) == variables.end()) { variables.push_front(variable); } } std::vector> reorder(std::vector>> T, int root_original) { // Create the edges of a DAG from the MST // replacing unordered_set with list because unordered_set cannot guarantee the order of the elements inserted auto result = std::vector>(); auto visited = std::vector(); auto nextVariables = std::list(); nextVariables.push_front(root_original); while (nextVariables.size() > 0) { int root = nextVariables.front(); nextVariables.pop_front(); for (int i = 0; i < T.size(); ++i) { auto [weight, edge] = T[i]; auto [from, to] = edge; if (from == root || to == root) { visited.insert(visited.begin(), i); if (from == root) { result.push_back({ from, to }); insertElement(nextVariables, to); } else { result.push_back({ to, from }); insertElement(nextVariables, from); } } } // Remove visited for (int i = 0; i < visited.size(); ++i) { T.erase(T.begin() + visited[i]); } visited.clear(); } if (T.size() > 0) { for (int i = 0; i < T.size(); ++i) { auto [weight, edge] = T[i]; auto [from, to] = edge; result.push_back({ from, to }); } } return result; } MST::MST(const std::vector& features, const torch::Tensor& weights, const int root) : features(features), weights(weights), root(root) {} std::vector> MST::maximumSpanningTree() { auto num_features = features.size(); Graph g(num_features); // Make a complete graph for (int i = 0; i < num_features - 1; ++i) { for (int j = i + 1; j < num_features; ++j) { g.addEdge(i, j, weights[i][j].item()); } } g.kruskal_algorithm(); auto mst = g.get_mst(); return reorder(mst, root); } }