Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
7 : #include <sstream>
8 : #include <vector>
9 : #include <list>
10 : #include "Mst.h"
11 : /*
12 : Based on the code from https://www.softwaretestinghelp.com/minimum-spanning-tree-tutorial/
13 :
14 : */
15 :
16 : namespace bayesnet {
17 638 : Graph::Graph(int V) : V(V), parent(std::vector<int>(V))
18 : {
19 2387 : for (int i = 0; i < V; i++)
20 2068 : parent[i] = i;
21 319 : G.clear();
22 319 : T.clear();
23 319 : }
24 6336 : void Graph::addEdge(int u, int v, float wt)
25 : {
26 6336 : G.push_back({ wt, { u, v } });
27 6336 : }
28 29293 : int Graph::find_set(int i)
29 : {
30 : // If i is the parent of itself
31 29293 : if (i == parent[i])
32 12672 : return i;
33 : else
34 : //else recursively find the parent of i
35 16621 : return find_set(parent[i]);
36 : }
37 1749 : void Graph::union_set(int u, int v)
38 : {
39 1749 : parent[u] = parent[v];
40 1749 : }
41 319 : void Graph::kruskal_algorithm()
42 : {
43 : // sort the edges ordered on decreasing weight
44 24662 : stable_sort(G.begin(), G.end(), [](const auto& left, const auto& right) {return left.first > right.first;});
45 6655 : for (int i = 0; i < G.size(); i++) {
46 : int uSt, vEd;
47 6336 : uSt = find_set(G[i].second.first);
48 6336 : vEd = find_set(G[i].second.second);
49 6336 : if (uSt != vEd) {
50 1749 : T.push_back(G[i]); // add to mst std::vector
51 1749 : union_set(uSt, vEd);
52 : }
53 : }
54 319 : }
55 :
/**
 * Push `variable` onto the front of `variables` unless it is already present,
 * keeping the list free of duplicates.
 */
void insertElement(std::list<int>& variables, int variable)
{
    for (const int queued : variables) {
        if (queued == variable)
            return; // already scheduled
    }
    variables.push_front(variable);
}
62 :
/**
 * Orient the undirected edges in T into directed (parent, child) pairs by
 * walking outward from root_original.
 *
 * Vertices waiting to be expanded are kept in a std::list used as a LIFO
 * stack: newly reached vertices go to the front, and a vertex already waiting
 * is not added twice. A std::list is used instead of an unordered_set because
 * it preserves insertion order, so the output is deterministic for a given
 * input order. Expanding a vertex consumes every remaining edge that touches
 * it, emitting (vertex, other endpoint) and scheduling the other endpoint.
 * Edges never reached from root_original are appended at the end in their
 * stored (from, to) orientation. Edge weights are ignored.
 * T is taken by value because it is consumed while edges are used.
 */
std::vector<std::pair<int, int>> reorder(std::vector<std::pair<float, std::pair<int, int>>> T, int root_original)
{
    std::vector<std::pair<int, int>> directed;
    std::list<int> pending{ root_original };
    while (!pending.empty()) {
        const int current = pending.front();
        pending.pop_front();
        // Edges not touching `current` survive to the next round, in their original order.
        std::vector<std::pair<float, std::pair<int, int>>> untouched;
        for (const auto& weighted : T) {
            const int a = weighted.second.first;
            const int b = weighted.second.second;
            if (a != current && b != current) {
                untouched.push_back(weighted);
                continue;
            }
            const int child = (a == current) ? b : a;
            directed.emplace_back(current, child);
            // Schedule the other endpoint unless it is already waiting.
            bool alreadyPending = false;
            for (const int queued : pending) {
                if (queued == child) {
                    alreadyPending = true;
                    break;
                }
            }
            if (!alreadyPending)
                pending.push_front(child);
        }
        T.swap(untouched);
    }
    // Edges unreachable from root_original keep their stored orientation.
    for (const auto& weighted : T)
        directed.push_back(weighted.second);
    return directed;
}
103 :
104 319 : MST::MST(const std::vector<std::string>& features, const torch::Tensor& weights, const int root) : features(features), weights(weights), root(root) {}
105 319 : std::vector<std::pair<int, int>> MST::maximumSpanningTree()
106 : {
107 319 : auto num_features = features.size();
108 319 : Graph g(num_features);
109 : // Make a complete graph
110 2068 : for (int i = 0; i < num_features - 1; ++i) {
111 8085 : for (int j = i + 1; j < num_features; ++j) {
112 6336 : g.addEdge(i, j, weights[i][j].item<float>());
113 : }
114 : }
115 319 : g.kruskal_algorithm();
116 319 : auto mst = g.get_mst();
117 638 : return reorder(mst, root);
118 319 : }
119 :
120 : }
|