Files
BayesNet/html/bayesnet/utils/Mst.cc.gcov.html
2024-05-06 17:56:00 +02:00

18 KiB

<html lang="en"> <head> </head>
LCOV - code coverage report
Current view: top level - bayesnet/utils - Mst.cc (source / functions) Coverage Total Hit
Test: BayesNet Coverage Report Lines: 94.1 % 68 64
Test Date: 2024-05-06 17:54:04 Functions: 100.0 % 10 10
Legend: Lines: hit not hit

            Line data    Source code
       1              : // ***************************************************************
       2              : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
       3              : // SPDX-FileType: SOURCE
       4              : // SPDX-License-Identifier: MIT
       5              : // ***************************************************************
       6              : 
       7              : #include <sstream>
       8              : #include <vector>
       9              : #include <list>
      10              : #include "Mst.h"
      11              : /*
      12              :     Based on the code from https://www.softwaretestinghelp.com/minimum-spanning-tree-tutorial/
      13              : 
      14              : */
      15              : 
      16              : namespace bayesnet {
      17          296 :     Graph::Graph(int V) : V(V), parent(std::vector<int>(V))
      18              :     {
      19         1124 :         for (int i = 0; i < V; i++)
      20          976 :             parent[i] = i;
      21          148 :         G.clear();
      22          148 :         T.clear();
      23          148 :     }
      24         3032 :     void Graph::addEdge(int u, int v, float wt)
      25              :     {
      26         3032 :         G.push_back({ wt, { u, v } });
      27         3032 :     }
      28        14076 :     int Graph::find_set(int i)
      29              :     {
      30              :         // If i is the parent of itself
      31        14076 :         if (i == parent[i])
      32         6064 :             return i;
      33              :         else
      34              :             //else recursively find the parent of i
      35         8012 :             return find_set(parent[i]);
      36              :     }
      37          828 :     void Graph::union_set(int u, int v)
      38              :     {
      39          828 :         parent[u] = parent[v];
      40          828 :     }
      41          148 :     void Graph::kruskal_algorithm()
      42              :     {
      43              :         // sort the edges ordered on decreasing weight
      44        11864 :         stable_sort(G.begin(), G.end(), [](const auto& left, const auto& right) {return left.first > right.first;});
      45         3180 :         for (int i = 0; i < G.size(); i++) {
      46              :             int uSt, vEd;
      47         3032 :             uSt = find_set(G[i].second.first);
      48         3032 :             vEd = find_set(G[i].second.second);
      49         3032 :             if (uSt != vEd) {
      50          828 :                 T.push_back(G[i]); // add to mst std::vector
      51          828 :                 union_set(uSt, vEd);
      52              :             }
      53              :         }
      54          148 :     }
      55              : 
      56          828 :     void insertElement(std::list<int>& variables, int variable)
      57              :     {
      58          828 :         if (std::find(variables.begin(), variables.end(), variable) == variables.end()) {
      59          828 :             variables.push_front(variable);
      60              :         }
      61          828 :     }
      62              : 
      63          148 :     std::vector<std::pair<int, int>> reorder(std::vector<std::pair<float, std::pair<int, int>>> T, int root_original)
      64              :     {
      65              :         // Create the edges of a DAG from the MST
      66              :         // replacing unordered_set with list because unordered_set cannot guarantee the order of the elements inserted
      67          148 :         auto result = std::vector<std::pair<int, int>>();
      68          148 :         auto visited = std::vector<int>();
      69          148 :         auto nextVariables = std::list<int>();
      70          148 :         nextVariables.push_front(root_original);
      71         1124 :         while (nextVariables.size() > 0) {
      72          976 :             int root = nextVariables.front();
      73          976 :             nextVariables.pop_front();
      74         3464 :             for (int i = 0; i < T.size(); ++i) {
      75         2488 :                 auto [weight, edge] = T[i];
      76         2488 :                 auto [from, to] = edge;
      77         2488 :                 if (from == root || to == root) {
      78          828 :                     visited.insert(visited.begin(), i);
      79          828 :                     if (from == root) {
      80          560 :                         result.push_back({ from, to });
      81          560 :                         insertElement(nextVariables, to);
      82              :                     } else {
      83          268 :                         result.push_back({ to, from });
      84          268 :                         insertElement(nextVariables, from);
      85              :                     }
      86              :                 }
      87              :             }
      88              :             // Remove visited
      89         1804 :             for (int i = 0; i < visited.size(); ++i) {
      90          828 :                 T.erase(T.begin() + visited[i]);
      91              :             }
      92          976 :             visited.clear();
      93              :         }
      94          148 :         if (T.size() > 0) {
      95            0 :             for (int i = 0; i < T.size(); ++i) {
      96            0 :                 auto [weight, edge] = T[i];
      97            0 :                 auto [from, to] = edge;
      98            0 :                 result.push_back({ from, to });
      99              :             }
     100              :         }
     101          296 :         return result;
     102          148 :     }
     103              : 
     104          148 :     MST::MST(const std::vector<std::string>& features, const torch::Tensor& weights, const int root) : features(features), weights(weights), root(root) {}
     105          148 :     std::vector<std::pair<int, int>> MST::maximumSpanningTree()
     106              :     {
     107          148 :         auto num_features = features.size();
     108          148 :         Graph g(num_features);
     109              :         // Make a complete graph
     110          976 :         for (int i = 0; i < num_features - 1; ++i) {
     111         3860 :             for (int j = i + 1; j < num_features; ++j) {
     112         3032 :                 g.addEdge(i, j, weights[i][j].item<float>());
     113              :             }
     114              :         }
     115          148 :         g.kruskal_algorithm();
     116          148 :         auto mst = g.get_mst();
     117          296 :         return reorder(mst, root);
     118          148 :     }
     119              : 
     120              : }
        

Generated by: LCOV version 2.0-1

</html>