2024-04-11 16:02:49 +00:00
|
|
|
// ***************************************************************
|
|
|
|
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
|
|
|
|
// SPDX-FileType: SOURCE
|
|
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
// ***************************************************************
|
|
|
|
|
2024-04-11 15:29:46 +00:00
|
|
|
#include <algorithm>
#include <cstddef>
#include <list>
#include <sstream>
#include <utility>
#include <vector>
#include "Mst.h"
|
2023-07-15 16:31:50 +00:00
|
|
|
/*
|
|
|
|
Based on the code from https://www.softwaretestinghelp.com/minimum-spanning-tree-tutorial/
|
|
|
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
namespace bayesnet {
|
2023-11-08 17:45:35 +00:00
|
|
|
// Creates a graph with V nodes and no edges.
// Every node starts as the representative of its own set (parent[i] == i),
// which is the initial state expected by find_set / union_set.
//
// @param V number of nodes in the graph
Graph::Graph(int V) : V(V), parent(V)
{
    // Each node is initially its own parent
    int node = 0;
    for (auto& representative : parent) {
        representative = node++;
    }
    // Start with empty edge and spanning-tree containers
    G.clear();
    T.clear();
}
|
|
|
|
void Graph::addEdge(int u, int v, float wt)
|
|
|
|
{
|
|
|
|
G.push_back({ wt, { u, v } });
|
|
|
|
}
|
|
|
|
int Graph::find_set(int i)
|
|
|
|
{
|
|
|
|
// If i is the parent of itself
|
|
|
|
if (i == parent[i])
|
|
|
|
return i;
|
|
|
|
else
|
|
|
|
//else recursively find the parent of i
|
|
|
|
return find_set(parent[i]);
|
|
|
|
}
|
|
|
|
void Graph::union_set(int u, int v)
|
|
|
|
{
|
|
|
|
parent[u] = parent[v];
|
|
|
|
}
|
|
|
|
void Graph::kruskal_algorithm()
|
|
|
|
{
|
|
|
|
// sort the edges ordered on decreasing weight
|
2023-10-06 15:08:54 +00:00
|
|
|
stable_sort(G.begin(), G.end(), [](const auto& left, const auto& right) {return left.first > right.first;});
|
2023-07-29 22:04:18 +00:00
|
|
|
for (int i = 0; i < G.size(); i++) {
|
|
|
|
int uSt, vEd;
|
2023-07-15 16:31:50 +00:00
|
|
|
uSt = find_set(G[i].second.first);
|
|
|
|
vEd = find_set(G[i].second.second);
|
|
|
|
if (uSt != vEd) {
|
2023-11-08 17:45:35 +00:00
|
|
|
T.push_back(G[i]); // add to mst std::vector
|
2023-07-15 16:31:50 +00:00
|
|
|
union_set(uSt, vEd);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-11-08 17:45:35 +00:00
|
|
|
// Pushes `variable` to the front of `variables` unless it is already present.
//
// @param variables list of pending node ids, modified in place
// @param variable  node id to add
void insertElement(std::list<int>& variables, int variable)
{
    if (std::find(variables.begin(), variables.end(), variable) == variables.end()) {
        variables.push_front(variable);
    }
}

// Creates the edges of a DAG from the MST by orienting every edge away from
// the given root.
//
// Starting at root_original, every edge touching the current node is emitted
// as (current node, other endpoint) and removed from T; the other endpoint is
// then scheduled to be expanded. Newly discovered nodes are pushed to the
// FRONT of the pending list, so the most recently discovered node is expanded
// first.
// A list is used instead of an unordered_set because an unordered_set cannot
// guarantee the order of the elements inserted.
//
// @param T             edges as (weight, (from, to)); taken by value because
//                      it is consumed while the edges are emitted
// @param root_original node where the traversal starts
// @return directed edges in the order they were discovered, followed by any
//         edge that could not be reached from the root, in its stored
//         (from, to) orientation
std::vector<std::pair<int, int>> reorder(std::vector<std::pair<float, std::pair<int, int>>> T, int root_original)
{
    std::vector<std::pair<int, int>> result;
    result.reserve(T.size());
    std::list<int> nextVariables{ root_original };
    while (!nextVariables.empty()) {
        const int root = nextVariables.front();
        nextVariables.pop_front();
        // Single pass over T: edges touching `root` are emitted, the rest are
        // compacted towards the front keeping their relative order.
        std::size_t kept = 0;
        for (std::size_t i = 0; i < T.size(); ++i) {
            const auto [from, to] = T[i].second;
            if (from != root && to != root) {
                if (kept != i) {
                    T[kept] = T[i];
                }
                ++kept;
                continue;
            }
            // The endpoint that is not the current root becomes its child
            const int child = (from == root) ? to : from;
            result.emplace_back(root, child);
            insertElement(nextVariables, child);
        }
        // Drop the edges that were emitted in this pass
        T.resize(kept);
    }
    // Edges not connected to the root are returned as they were stored
    for (const auto& entry : T) {
        result.push_back(entry.second);
    }
    return result;
}
|
|
|
|
|
2023-11-08 17:45:35 +00:00
|
|
|
// Stores the inputs needed to build the maximum spanning tree.
//
// @param features names of the features; their count is the number of nodes
// @param weights  pairwise weights between features, read as weights[i][j]
//                 by maximumSpanningTree
// @param root     index of the node used as root when orienting the tree
MST::MST(const std::vector<std::string>& features, const torch::Tensor& weights, const int root)
    : features(features),
      weights(weights),
      root(root)
{
}
|
|
|
|
std::vector<std::pair<int, int>> MST::maximumSpanningTree()
|
2023-07-15 16:31:50 +00:00
|
|
|
{
|
|
|
|
auto num_features = features.size();
|
|
|
|
Graph g(num_features);
|
|
|
|
// Make a complete graph
|
|
|
|
for (int i = 0; i < num_features - 1; ++i) {
|
2023-07-27 14:51:27 +00:00
|
|
|
for (int j = i + 1; j < num_features; ++j) {
|
2023-07-15 16:31:50 +00:00
|
|
|
g.addEdge(i, j, weights[i][j].item<float>());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
g.kruskal_algorithm();
|
|
|
|
auto mst = g.get_mst();
|
|
|
|
return reorder(mst, root);
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|