Files
BayesNet/docs/manual/_node_8cc_source.html

25 KiB

<html xmlns="http://www.w3.org/1999/xhtml" lang="en-US"> <head> <script type="text/javascript" src="jquery.js"></script> <script type="text/javascript" src="dynsections.js"></script> <script type="text/javascript" src="clipboard.js"></script> <script type="text/javascript" src="navtreedata.js"></script> <script type="text/javascript" src="navtree.js"></script> <script type="text/javascript" src="resize.js"></script> <script type="text/javascript" src="cookie.js"></script> <script type="text/javascript" src="search/searchdata.js"></script> <script type="text/javascript" src="search/search.js"></script> </head>
BayesNet 1.0.5
Bayesian Network Classifiers using libtorch from scratch
<!-- Doxygen page bootstrap (generated). In order: create the SearchBox object named "searchBox" (results under "search/", suffix ".html"); call codefold.init(0) on jQuery DOM-ready; load menudata.js and menu.js; then on DOM-ready build the top menu with initMenu(...) and call init_search(). The order matters because later handlers use globals defined by the earlier scripts. -->
<script type="text/javascript"> /* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699&dn=expat.txt MIT */ var searchBox = new SearchBox("searchBox", "search/",'.html'); /* @license-end */ </script> <script type="text/javascript"> /* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699&dn=expat.txt MIT */ $(function() { codefold.init(0); }); /* @license-end */ </script> <script type="text/javascript" src="menudata.js"></script> <script type="text/javascript" src="menu.js"></script> <script type="text/javascript"> /* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699&dn=expat.txt MIT */ $(function() { initMenu('',true,false,'search.php','Search',true); $(function() { init_search(); }); }); /* @license-end */ </script>
<!-- On DOM-ready: initialise the navigation tree for this page ('_node_8cc_source.html') and call initResizable(true). -->
<script type="text/javascript"> /* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699&dn=expat.txt MIT */ $(function(){initNavTree('_node_8cc_source.html',''); initResizable(true); }); /* @license-end */ </script>
Loading...
Searching...
No Matches
Node.cc
1// ***************************************************************
2// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3// SPDX-FileType: SOURCE
4// SPDX-License-Identifier: MIT
5// ***************************************************************
6
7#include "Node.h"
8
9namespace bayesnet {
10
11 Node::Node(const std::string& name)
12 : name(name)
13 {
14 }
15 void Node::clear()
16 {
17 parents.clear();
18 children.clear();
19 cpTable = torch::Tensor();
20 dimensions.clear();
21 numStates = 0;
22 }
23 std::string Node::getName() const
24 {
25 return name;
26 }
27 void Node::addParent(Node* parent)
28 {
29 parents.push_back(parent);
30 }
31 void Node::removeParent(Node* parent)
32 {
33 parents.erase(std::remove(parents.begin(), parents.end(), parent), parents.end());
34 }
35 void Node::removeChild(Node* child)
36 {
37 children.erase(std::remove(children.begin(), children.end(), child), children.end());
38 }
39 void Node::addChild(Node* child)
40 {
41 children.push_back(child);
42 }
43 std::vector<Node*>& Node::getParents()
44 {
45 return parents;
46 }
47 std::vector<Node*>& Node::getChildren()
48 {
49 return children;
50 }
51 int Node::getNumStates() const
52 {
53 return numStates;
54 }
55 void Node::setNumStates(int numStates)
56 {
57 this->numStates = numStates;
58 }
59 torch::Tensor& Node::getCPT()
60 {
61 return cpTable;
62 }
63 /*
64 The MinFill criterion is a heuristic for variable elimination.
65 The variable that minimizes the number of edges that need to be added to the graph to make it triangulated.
66 This is done by counting the number of edges that need to be added to the graph if the variable is eliminated.
67 The variable with the minimum number of edges is chosen.
68 Here this is done computing the length of the combinations of the node neighbors taken 2 by 2.
69 */
70 unsigned Node::minFill()
71 {
72 std::unordered_set<std::string> neighbors;
73 for (auto child : children) {
74 neighbors.emplace(child->getName());
75 }
76 for (auto parent : parents) {
77 neighbors.emplace(parent->getName());
78 }
79 auto source = std::vector<std::string>(neighbors.begin(), neighbors.end());
80 return combinations(source).size();
81 }
82 std::vector<std::pair<std::string, std::string>> Node::combinations(const std::vector<std::string>& source)
83 {
84 std::vector<std::pair<std::string, std::string>> result;
85 for (int i = 0; i < source.size(); ++i) {
86 std::string temp = source[i];
87 for (int j = i + 1; j < source.size(); ++j) {
88 result.push_back({ temp, source[j] });
89 }
90 }
91 return result;
92 }
93 void Node::computeCPT(const torch::Tensor& dataset, const std::vector<std::string>& features, const double laplaceSmoothing, const torch::Tensor& weights)
94 {
95 dimensions.clear();
96 // Get dimensions of the CPT
97 dimensions.push_back(numStates);
98 transform(parents.begin(), parents.end(), back_inserter(dimensions), [](const auto& parent) { return parent->getNumStates(); });
99 // Create a tensor of zeros with the dimensions of the CPT
100 cpTable = torch::zeros(dimensions, torch::kFloat) + laplaceSmoothing;
101 // Fill table with counts
102 auto pos = find(features.begin(), features.end(), name);
103 if (pos == features.end()) {
104 throw std::logic_error("Feature " + name + " not found in dataset");
105 }
106 int name_index = pos - features.begin();
107 for (int n_sample = 0; n_sample < dataset.size(1); ++n_sample) {
108 c10::List<c10::optional<at::Tensor>> coordinates;
109 coordinates.push_back(dataset.index({ name_index, n_sample }));
110 for (auto parent : parents) {
111 pos = find(features.begin(), features.end(), parent->getName());
112 if (pos == features.end()) {
113 throw std::logic_error("Feature parent " + parent->getName() + " not found in dataset");
114 }
115 int parent_index = pos - features.begin();
116 coordinates.push_back(dataset.index({ parent_index, n_sample }));
117 }
118 // Increment the count of the corresponding coordinate
119 cpTable.index_put_({ coordinates }, cpTable.index({ coordinates }) + weights.index({ n_sample }).item<double>());
120 }
121 // Normalize the counts
122 cpTable = cpTable / cpTable.sum(0);
123 }
124 float Node::getFactorValue(std::map<std::string, int>& evidence)
125 {
126 c10::List<c10::optional<at::Tensor>> coordinates;
127 // following predetermined order of indices in the cpTable (see Node.h)
128 coordinates.push_back(at::tensor(evidence[name]));
129 transform(parents.begin(), parents.end(), std::back_inserter(coordinates), [&evidence](const auto& parent) { return at::tensor(evidence[parent->getName()]); });
130 return cpTable.index({ coordinates }).item<float>();
131 }
132 std::vector<std::string> Node::graph(const std::string& className)
133 {
134 auto output = std::vector<std::string>();
135 auto suffix = name == className ? ", fontcolor=red, fillcolor=lightblue, style=filled " : "";
136 output.push_back(name + " [shape=circle" + suffix + "] \n");
137 transform(children.begin(), children.end(), back_inserter(output), [this](const auto& child) { return name + " -> " + child->getName(); });
138 return output;
139 }
140}
</html>