18 KiB
18 KiB
<html lang="en">
<head>
</head>
</html>
LCOV - code coverage report | |||||||||||||||||||||||||
![]() | |||||||||||||||||||||||||
|
|||||||||||||||||||||||||
![]() |
Line data Source code 1 : // *************************************************************** 2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez 3 : // SPDX-FileType: SOURCE 4 : // SPDX-License-Identifier: MIT 5 : // *************************************************************** 6 : 7 : #include <sstream> 8 : #include <vector> 9 : #include <list> 10 : #include "Mst.h" 11 : /* 12 : Based on the code from https://www.softwaretestinghelp.com/minimum-spanning-tree-tutorial/ 13 : 14 : */ 15 : 16 : namespace bayesnet { 17 296 : Graph::Graph(int V) : V(V), parent(std::vector<int>(V)) 18 : { 19 1124 : for (int i = 0; i < V; i++) 20 976 : parent[i] = i; 21 148 : G.clear(); 22 148 : T.clear(); 23 148 : } 24 3032 : void Graph::addEdge(int u, int v, float wt) 25 : { 26 3032 : G.push_back({ wt, { u, v } }); 27 3032 : } 28 14076 : int Graph::find_set(int i) 29 : { 30 : // If i is the parent of itself 31 14076 : if (i == parent[i]) 32 6064 : return i; 33 : else 34 : //else recursively find the parent of i 35 8012 : return find_set(parent[i]); 36 : } 37 828 : void Graph::union_set(int u, int v) 38 : { 39 828 : parent[u] = parent[v]; 40 828 : } 41 148 : void Graph::kruskal_algorithm() 42 : { 43 : // sort the edges ordered on decreasing weight 44 11864 : stable_sort(G.begin(), G.end(), [](const auto& left, const auto& right) {return left.first > right.first;}); 45 3180 : for (int i = 0; i < G.size(); i++) { 46 : int uSt, vEd; 47 3032 : uSt = find_set(G[i].second.first); 48 3032 : vEd = find_set(G[i].second.second); 49 3032 : if (uSt != vEd) { 50 828 : T.push_back(G[i]); // add to mst std::vector 51 828 : union_set(uSt, vEd); 52 : } 53 : } 54 148 : } 55 : 56 828 : void insertElement(std::list<int>& variables, int variable) 57 : { 58 828 : if (std::find(variables.begin(), variables.end(), variable) == variables.end()) { 59 828 : variables.push_front(variable); 60 : } 61 828 : } 62 : 63 148 : std::vector<std::pair<int, int>> reorder(std::vector<std::pair<float, std::pair<int, int>>> T, int root_original) 64 : { 65 : // Create the edges of a DAG from the MST 66 : // replacing unordered_set with list because unordered_set cannot guarantee the order of the elements inserted 67 148 : auto result = std::vector<std::pair<int, int>>(); 68 148 : auto visited = std::vector<int>(); 69 148 : auto nextVariables = std::list<int>(); 70 148 : nextVariables.push_front(root_original); 71 1124 : while (nextVariables.size() > 0) { 72 976 : int root = nextVariables.front(); 73 976 : nextVariables.pop_front(); 74 3464 : for (int i = 0; i < T.size(); ++i) { 75 2488 : auto [weight, edge] = T[i]; 76 2488 : auto [from, to] = edge; 77 2488 : if (from == root || to == root) { 78 828 : visited.insert(visited.begin(), i); 79 828 : if (from == root) { 80 560 : result.push_back({ from, to }); 81 560 : insertElement(nextVariables, to); 82 : } else { 83 268 : result.push_back({ to, from }); 84 268 : insertElement(nextVariables, from); 85 : } 86 : } 87 : } 88 : // Remove visited 89 1804 : for (int i = 0; i < visited.size(); ++i) { 90 828 : T.erase(T.begin() + visited[i]); 91 : } 92 976 : visited.clear(); 93 : } 94 148 : if (T.size() > 0) { 95 0 : for (int i = 0; i < T.size(); ++i) { 96 0 : auto [weight, edge] = T[i]; 97 0 : auto [from, to] = edge; 98 0 : result.push_back({ from, to }); 99 : } 100 : } 101 296 : return result; 102 148 : } 103 : 104 148 : MST::MST(const std::vector<std::string>& features, const torch::Tensor& weights, const int root) : features(features), weights(weights), root(root) {} 105 148 : std::vector<std::pair<int, int>> MST::maximumSpanningTree() 106 : { 107 148 : auto num_features = features.size(); 108 148 : Graph g(num_features); 109 : // Make a complete graph 110 976 : for (int i = 0; i < num_features - 1; ++i) { 111 3860 : for (int j = i + 1; j < num_features; ++j) { 112 3032 : g.addEdge(i, j, weights[i][j].item<float>()); 113 : } 114 : } 115 148 : g.kruskal_algorithm(); 116 148 : auto mst = g.get_mst(); 117 296 : return reorder(mst, root); 118 148 : } 119 : 120 : } |
![]() |
Generated by: LCOV version 2.0-1 |
</html>