Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
7 : #ifndef MST_H
8 : #define MST_H
9 : #include <vector>
10 : #include <string>
11 : #include <torch/torch.h>
12 : namespace bayesnet {
13 : class MST {
14 : public:
15 : MST() = default;
16 : MST(const std::vector<std::string>& features, const torch::Tensor& weights, const int root);
17 : std::vector<std::pair<int, int>> maximumSpanningTree();
18 : private:
19 : torch::Tensor weights;
20 : std::vector<std::string> features;
21 : int root = 0;
22 : };
23 : class Graph {
24 : public:
25 : explicit Graph(int V);
26 : void addEdge(int u, int v, float wt);
27 : int find_set(int i);
28 : void union_set(int u, int v);
29 : void kruskal_algorithm();
30 58 : std::vector <std::pair<float, std::pair<int, int>>> get_mst() { return T; }
31 : private:
32 : int V; // number of nodes in graph
33 : std::vector <std::pair<float, std::pair<int, int>>> G; // std::vector for graph
34 : std::vector <std::pair<float, std::pair<int, int>>> T; // std::vector for mst
35 : std::vector<int> parent;
36 : };
37 : }
38 : #endif
|