Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
7 : #include "Node.h"
8 :
9 : namespace bayesnet {
10 :
11 116977 : Node::Node(const std::string& name)
12 116977 : : name(name), numStates(0), cpTable(torch::Tensor()), parents(std::vector<Node*>()), children(std::vector<Node*>())
13 : {
14 116977 : }
15 11 : void Node::clear()
16 : {
17 11 : parents.clear();
18 11 : children.clear();
19 11 : cpTable = torch::Tensor();
20 11 : dimensions.clear();
21 11 : numStates = 0;
22 11 : }
23 159096442 : std::string Node::getName() const
24 : {
25 159096442 : return name;
26 : }
27 224331 : void Node::addParent(Node* parent)
28 : {
29 224331 : parents.push_back(parent);
30 224331 : }
31 33 : void Node::removeParent(Node* parent)
32 : {
33 33 : parents.erase(std::remove(parents.begin(), parents.end(), parent), parents.end());
34 33 : }
35 33 : void Node::removeChild(Node* child)
36 : {
37 33 : children.erase(std::remove(children.begin(), children.end(), child), children.end());
38 33 : }
39 224353 : void Node::addChild(Node* child)
40 : {
41 224353 : children.push_back(child);
42 224353 : }
43 13948 : std::vector<Node*>& Node::getParents()
44 : {
45 13948 : return parents;
46 : }
47 306501 : std::vector<Node*>& Node::getChildren()
48 : {
49 306501 : return children;
50 : }
51 236412 : int Node::getNumStates() const
52 : {
53 236412 : return numStates;
54 : }
55 121388 : void Node::setNumStates(int numStates)
56 : {
57 121388 : this->numStates = numStates;
58 121388 : }
59 1155 : torch::Tensor& Node::getCPT()
60 : {
61 1155 : return cpTable;
62 : }
63 : /*
64 : The MinFill criterion is a heuristic for variable elimination.
65 : The variable that minimizes the number of edges that need to be added to the graph to make it triangulated.
66 : This is done by counting the number of edges that need to be added to the graph if the variable is eliminated.
67 : The variable with the minimum number of edges is chosen.
68 : Here this is done computing the length of the combinations of the node neighbors taken 2 by 2.
69 : */
70 55 : unsigned Node::minFill()
71 : {
72 55 : std::unordered_set<std::string> neighbors;
73 143 : for (auto child : children) {
74 88 : neighbors.emplace(child->getName());
75 : }
76 132 : for (auto parent : parents) {
77 77 : neighbors.emplace(parent->getName());
78 : }
79 55 : auto source = std::vector<std::string>(neighbors.begin(), neighbors.end());
80 110 : return combinations(source).size();
81 55 : }
82 55 : std::vector<std::pair<std::string, std::string>> Node::combinations(const std::vector<std::string>& source)
83 : {
84 55 : std::vector<std::pair<std::string, std::string>> result;
85 220 : for (int i = 0; i < source.size(); ++i) {
86 165 : std::string temp = source[i];
87 341 : for (int j = i + 1; j < source.size(); ++j) {
88 176 : result.push_back({ temp, source[j] });
89 : }
90 165 : }
91 55 : return result;
92 0 : }
93 121388 : void Node::computeCPT(const torch::Tensor& dataset, const std::vector<std::string>& features, const double laplaceSmoothing, const torch::Tensor& weights)
94 : {
95 121388 : dimensions.clear();
96 : // Get dimensions of the CPT
97 121388 : dimensions.push_back(numStates);
98 353397 : transform(parents.begin(), parents.end(), back_inserter(dimensions), [](const auto& parent) { return parent->getNumStates(); });
99 :
100 : // Create a tensor of zeros with the dimensions of the CPT
101 121388 : cpTable = torch::zeros(dimensions, torch::kFloat) + laplaceSmoothing;
102 : // Fill table with counts
103 121388 : auto pos = find(features.begin(), features.end(), name);
104 121388 : if (pos == features.end()) {
105 0 : throw std::logic_error("Feature " + name + " not found in dataset");
106 : }
107 121388 : int name_index = pos - features.begin();
108 21284350 : for (int n_sample = 0; n_sample < dataset.size(1); ++n_sample) {
109 21162962 : c10::List<c10::optional<at::Tensor>> coordinates;
110 63488886 : coordinates.push_back(dataset.index({ name_index, n_sample }));
111 60665104 : for (auto parent : parents) {
112 39502142 : pos = find(features.begin(), features.end(), parent->getName());
113 39502142 : if (pos == features.end()) {
114 0 : throw std::logic_error("Feature parent " + parent->getName() + " not found in dataset");
115 : }
116 39502142 : int parent_index = pos - features.begin();
117 118506426 : coordinates.push_back(dataset.index({ parent_index, n_sample }));
118 : }
119 : // Increment the count of the corresponding coordinate
120 42325924 : cpTable.index_put_({ coordinates }, cpTable.index({ coordinates }) + weights.index({ n_sample }).item<double>());
121 21162962 : }
122 : // Normalize the counts
123 121388 : cpTable = cpTable / cpTable.sum(0);
124 81949454 : }
125 67302838 : float Node::getFactorValue(std::map<std::string, int>& evidence)
126 : {
127 67302838 : c10::List<c10::optional<at::Tensor>> coordinates;
128 : // following predetermined order of indices in the cpTable (see Node.h)
129 67302838 : coordinates.push_back(at::tensor(evidence[name]));
130 186413412 : transform(parents.begin(), parents.end(), std::back_inserter(coordinates), [&evidence](const auto& parent) { return at::tensor(evidence[parent->getName()]); });
131 134605676 : return cpTable.index({ coordinates }).item<float>();
132 67302838 : }
133 1683 : std::vector<std::string> Node::graph(const std::string& className)
134 : {
135 1683 : auto output = std::vector<std::string>();
136 1683 : auto suffix = name == className ? ", fontcolor=red, fillcolor=lightblue, style=filled " : "";
137 1683 : output.push_back(name + " [shape=circle" + suffix + "] \n");
138 4334 : transform(children.begin(), children.end(), back_inserter(output), [this](const auto& child) { return name + " -> " + child->getName(); });
139 1683 : return output;
140 0 : }
141 : }
|