Files
BayesNet/docs/manual/_network_8cc_source.html

73 KiB

<html xmlns="http://www.w3.org/1999/xhtml" lang="en-US"> <head> <script type="text/javascript" src="jquery.js"></script> <script type="text/javascript" src="dynsections.js"></script> <script type="text/javascript" src="clipboard.js"></script> <script type="text/javascript" src="navtreedata.js"></script> <script type="text/javascript" src="navtree.js"></script> <script type="text/javascript" src="resize.js"></script> <script type="text/javascript" src="cookie.js"></script> <script type="text/javascript" src="search/searchdata.js"></script> <script type="text/javascript" src="search/search.js"></script> </head>
BayesNet 1.0.5
Bayesian Network Classifiers using libtorch from scratch
<script type="text/javascript"> /* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699&dn=expat.txt MIT */ var searchBox = new SearchBox("searchBox", "search/",'.html'); /* @license-end */ </script> <script type="text/javascript"> /* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699&dn=expat.txt MIT */ $(function() { codefold.init(0); }); /* @license-end */ </script> <script type="text/javascript" src="menudata.js"></script> <script type="text/javascript" src="menu.js"></script> <script type="text/javascript"> /* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699&dn=expat.txt MIT */ $(function() { initMenu('',true,false,'search.php','Search',true); $(function() { init_search(); }); }); /* @license-end */ </script>
<script type="text/javascript"> /* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699&dn=expat.txt MIT */ $(function(){initNavTree('_network_8cc_source.html',''); initResizable(true); }); /* @license-end */ </script>
Loading...
Searching...
No Matches
Network.cc
1// ***************************************************************
2// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3// SPDX-FileType: SOURCE
4// SPDX-License-Identifier: MIT
5// ***************************************************************
6
#include <algorithm>
#include <exception>
#include <mutex>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <thread>
#include "Network.h"
#include "bayesnet/utils/bayesnetUtils.h"
12namespace bayesnet {
13 Network::Network() : fitted{ false }, maxThreads{ 0.95 }, classNumStates{ 0 }, laplaceSmoothing{ 0 }
14 {
15 }
16 Network::Network(float maxT) : fitted{ false }, maxThreads{ maxT }, classNumStates{ 0 }, laplaceSmoothing{ 0 }
17 {
18
19 }
20 Network::Network(const Network& other) : laplaceSmoothing(other.laplaceSmoothing), features(other.features), className(other.className), classNumStates(other.getClassNumStates()),
21 maxThreads(other.getMaxThreads()), fitted(other.fitted), samples(other.samples)
22 {
23 if (samples.defined())
24 samples = samples.clone();
25 for (const auto& node : other.nodes) {
26 nodes[node.first] = std::make_unique<Node>(*node.second);
27 }
28 }
29 void Network::initialize()
30 {
31 features.clear();
32 className = "";
33 classNumStates = 0;
34 fitted = false;
35 nodes.clear();
36 samples = torch::Tensor();
37 }
38 float Network::getMaxThreads() const
39 {
40 return maxThreads;
41 }
42 torch::Tensor& Network::getSamples()
43 {
44 return samples;
45 }
46 void Network::addNode(const std::string& name)
47 {
48 if (name == "") {
49 throw std::invalid_argument("Node name cannot be empty");
50 }
51 if (nodes.find(name) != nodes.end()) {
52 return;
53 }
54 if (find(features.begin(), features.end(), name) == features.end()) {
55 features.push_back(name);
56 }
57 nodes[name] = std::make_unique<Node>(name);
58 }
59 std::vector<std::string> Network::getFeatures() const
60 {
61 return features;
62 }
63 int Network::getClassNumStates() const
64 {
65 return classNumStates;
66 }
67 int Network::getStates() const
68 {
69 int result = 0;
70 for (auto& node : nodes) {
71 result += node.second->getNumStates();
72 }
73 return result;
74 }
75 std::string Network::getClassName() const
76 {
77 return className;
78 }
79 bool Network::isCyclic(const std::string& nodeId, std::unordered_set<std::string>& visited, std::unordered_set<std::string>& recStack)
80 {
81 if (visited.find(nodeId) == visited.end()) // if node hasn't been visited yet
82 {
83 visited.insert(nodeId);
84 recStack.insert(nodeId);
85 for (Node* child : nodes[nodeId]->getChildren()) {
86 if (visited.find(child->getName()) == visited.end() && isCyclic(child->getName(), visited, recStack))
87 return true;
88 if (recStack.find(child->getName()) != recStack.end())
89 return true;
90 }
91 }
92 recStack.erase(nodeId); // remove node from recursion stack before function ends
93 return false;
94 }
95 void Network::addEdge(const std::string& parent, const std::string& child)
96 {
97 if (nodes.find(parent) == nodes.end()) {
98 throw std::invalid_argument("Parent node " + parent + " does not exist");
99 }
100 if (nodes.find(child) == nodes.end()) {
101 throw std::invalid_argument("Child node " + child + " does not exist");
102 }
103 // Temporarily add edge to check for cycles
104 nodes[parent]->addChild(nodes[child].get());
105 nodes[child]->addParent(nodes[parent].get());
106 std::unordered_set<std::string> visited;
107 std::unordered_set<std::string> recStack;
108 if (isCyclic(nodes[child]->getName(), visited, recStack)) // if adding this edge forms a cycle
109 {
110 // remove problematic edge
111 nodes[parent]->removeChild(nodes[child].get());
112 nodes[child]->removeParent(nodes[parent].get());
113 throw std::invalid_argument("Adding this edge forms a cycle in the graph.");
114 }
115 }
116 std::map<std::string, std::unique_ptr<Node>>& Network::getNodes()
117 {
118 return nodes;
119 }
120 void Network::checkFitData(int n_samples, int n_features, int n_samples_y, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights)
121 {
122 if (weights.size(0) != n_samples) {
123 throw std::invalid_argument("Weights (" + std::to_string(weights.size(0)) + ") must have the same number of elements as samples (" + std::to_string(n_samples) + ") in Network::fit");
124 }
125 if (n_samples != n_samples_y) {
126 throw std::invalid_argument("X and y must have the same number of samples in Network::fit (" + std::to_string(n_samples) + " != " + std::to_string(n_samples_y) + ")");
127 }
128 if (n_features != featureNames.size()) {
129 throw std::invalid_argument("X and features must have the same number of features in Network::fit (" + std::to_string(n_features) + " != " + std::to_string(featureNames.size()) + ")");
130 }
131 if (features.size() == 0) {
132 throw std::invalid_argument("The network has not been initialized. You must call addNode() before calling fit()");
133 }
134 if (n_features != features.size() - 1) {
135 throw std::invalid_argument("X and local features must have the same number of features in Network::fit (" + std::to_string(n_features) + " != " + std::to_string(features.size() - 1) + ")");
136 }
137 if (find(features.begin(), features.end(), className) == features.end()) {
138 throw std::invalid_argument("Class Name not found in Network::features");
139 }
140 for (auto& feature : featureNames) {
141 if (find(features.begin(), features.end(), feature) == features.end()) {
142 throw std::invalid_argument("Feature " + feature + " not found in Network::features");
143 }
144 if (states.find(feature) == states.end()) {
145 throw std::invalid_argument("Feature " + feature + " not found in states");
146 }
147 }
148 }
149 void Network::setStates(const std::map<std::string, std::vector<int>>& states)
150 {
151 // Set states to every Node in the network
152 for_each(features.begin(), features.end(), [this, &states](const std::string& feature) {
153 nodes.at(feature)->setNumStates(states.at(feature).size());
154 });
155 classNumStates = nodes.at(className)->getNumStates();
156 }
157 // X comes in nxm, where n is the number of features and m the number of samples
158 void Network::fit(const torch::Tensor& X, const torch::Tensor& y, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states)
159 {
160 checkFitData(X.size(1), X.size(0), y.size(0), featureNames, className, states, weights);
161 this->className = className;
162 torch::Tensor ytmp = torch::transpose(y.view({ y.size(0), 1 }), 0, 1);
163 samples = torch::cat({ X , ytmp }, 0);
164 for (int i = 0; i < featureNames.size(); ++i) {
165 auto row_feature = X.index({ i, "..." });
166 }
167 completeFit(states, weights);
168 }
169 void Network::fit(const torch::Tensor& samples, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states)
170 {
171 checkFitData(samples.size(1), samples.size(0) - 1, samples.size(1), featureNames, className, states, weights);
172 this->className = className;
173 this->samples = samples;
174 completeFit(states, weights);
175 }
176 // input_data comes in nxm, where n is the number of features and m the number of samples
177 void Network::fit(const std::vector<std::vector<int>>& input_data, const std::vector<int>& labels, const std::vector<double>& weights_, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states)
178 {
179 const torch::Tensor weights = torch::tensor(weights_, torch::kFloat64);
180 checkFitData(input_data[0].size(), input_data.size(), labels.size(), featureNames, className, states, weights);
181 this->className = className;
182 // Build tensor of samples (nxm) (n+1 because of the class)
183 samples = torch::zeros({ static_cast<int>(input_data.size() + 1), static_cast<int>(input_data[0].size()) }, torch::kInt32);
184 for (int i = 0; i < featureNames.size(); ++i) {
185 samples.index_put_({ i, "..." }, torch::tensor(input_data[i], torch::kInt32));
186 }
187 samples.index_put_({ -1, "..." }, torch::tensor(labels, torch::kInt32));
188 completeFit(states, weights);
189 }
190 void Network::completeFit(const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights)
191 {
192 setStates(states);
193 laplaceSmoothing = 1.0 / samples.size(1); // To use in CPT computation
194 std::vector<std::thread> threads;
195 for (auto& node : nodes) {
196 threads.emplace_back([this, &node, &weights]() {
197 node.second->computeCPT(samples, features, laplaceSmoothing, weights);
198 });
199 }
200 for (auto& thread : threads) {
201 thread.join();
202 }
203 fitted = true;
204 }
205 torch::Tensor Network::predict_tensor(const torch::Tensor& samples, const bool proba)
206 {
207 if (!fitted) {
208 throw std::logic_error("You must call fit() before calling predict()");
209 }
210 torch::Tensor result;
211 result = torch::zeros({ samples.size(1), classNumStates }, torch::kFloat64);
212 for (int i = 0; i < samples.size(1); ++i) {
213 const torch::Tensor sample = samples.index({ "...", i });
214 auto psample = predict_sample(sample);
215 auto temp = torch::tensor(psample, torch::kFloat64);
216 // result.index_put_({ i, "..." }, torch::tensor(predict_sample(sample), torch::kFloat64));
217 result.index_put_({ i, "..." }, temp);
218 }
219 if (proba)
220 return result;
221 return result.argmax(1);
222 }
223 // Return mxn tensor of probabilities
224 torch::Tensor Network::predict_proba(const torch::Tensor& samples)
225 {
226 return predict_tensor(samples, true);
227 }
228
229 // Return mxn tensor of probabilities
230 torch::Tensor Network::predict(const torch::Tensor& samples)
231 {
232 return predict_tensor(samples, false);
233 }
234
235 // Return mx1 std::vector of predictions
236 // tsamples is nxm std::vector of samples
237 std::vector<int> Network::predict(const std::vector<std::vector<int>>& tsamples)
238 {
239 if (!fitted) {
240 throw std::logic_error("You must call fit() before calling predict()");
241 }
242 std::vector<int> predictions;
243 std::vector<int> sample;
244 for (int row = 0; row < tsamples[0].size(); ++row) {
245 sample.clear();
246 for (int col = 0; col < tsamples.size(); ++col) {
247 sample.push_back(tsamples[col][row]);
248 }
249 std::vector<double> classProbabilities = predict_sample(sample);
250 // Find the class with the maximum posterior probability
251 auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end());
252 int predictedClass = distance(classProbabilities.begin(), maxElem);
253 predictions.push_back(predictedClass);
254 }
255 return predictions;
256 }
257 // Return mxn std::vector of probabilities
258 // tsamples is nxm std::vector of samples
259 std::vector<std::vector<double>> Network::predict_proba(const std::vector<std::vector<int>>& tsamples)
260 {
261 if (!fitted) {
262 throw std::logic_error("You must call fit() before calling predict_proba()");
263 }
264 std::vector<std::vector<double>> predictions;
265 std::vector<int> sample;
266 for (int row = 0; row < tsamples[0].size(); ++row) {
267 sample.clear();
268 for (int col = 0; col < tsamples.size(); ++col) {
269 sample.push_back(tsamples[col][row]);
270 }
271 predictions.push_back(predict_sample(sample));
272 }
273 return predictions;
274 }
275 double Network::score(const std::vector<std::vector<int>>& tsamples, const std::vector<int>& labels)
276 {
277 std::vector<int> y_pred = predict(tsamples);
278 int correct = 0;
279 for (int i = 0; i < y_pred.size(); ++i) {
280 if (y_pred[i] == labels[i]) {
281 correct++;
282 }
283 }
284 return (double)correct / y_pred.size();
285 }
286 // Return 1xn std::vector of probabilities
287 std::vector<double> Network::predict_sample(const std::vector<int>& sample)
288 {
289 // Ensure the sample size is equal to the number of features
290 if (sample.size() != features.size() - 1) {
291 throw std::invalid_argument("Sample size (" + std::to_string(sample.size()) +
292 ") does not match the number of features (" + std::to_string(features.size() - 1) + ")");
293 }
294 std::map<std::string, int> evidence;
295 for (int i = 0; i < sample.size(); ++i) {
296 evidence[features[i]] = sample[i];
297 }
298 return exactInference(evidence);
299 }
300 // Return 1xn std::vector of probabilities
301 std::vector<double> Network::predict_sample(const torch::Tensor& sample)
302 {
303 // Ensure the sample size is equal to the number of features
304 if (sample.size(0) != features.size() - 1) {
305 throw std::invalid_argument("Sample size (" + std::to_string(sample.size(0)) +
306 ") does not match the number of features (" + std::to_string(features.size() - 1) + ")");
307 }
308 std::map<std::string, int> evidence;
309 for (int i = 0; i < sample.size(0); ++i) {
310 evidence[features[i]] = sample[i].item<int>();
311 }
312 return exactInference(evidence);
313 }
314 double Network::computeFactor(std::map<std::string, int>& completeEvidence)
315 {
316 double result = 1.0;
317 for (auto& node : getNodes()) {
318 result *= node.second->getFactorValue(completeEvidence);
319 }
320 return result;
321 }
322 std::vector<double> Network::exactInference(std::map<std::string, int>& evidence)
323 {
324 std::vector<double> result(classNumStates, 0.0);
325 std::vector<std::thread> threads;
326 std::mutex mtx;
327 for (int i = 0; i < classNumStates; ++i) {
328 threads.emplace_back([this, &result, &evidence, i, &mtx]() {
329 auto completeEvidence = std::map<std::string, int>(evidence);
330 completeEvidence[getClassName()] = i;
331 double factor = computeFactor(completeEvidence);
332 std::lock_guard<std::mutex> lock(mtx);
333 result[i] = factor;
334 });
335 }
336 for (auto& thread : threads) {
337 thread.join();
338 }
339 // Normalize result
340 double sum = accumulate(result.begin(), result.end(), 0.0);
341 transform(result.begin(), result.end(), result.begin(), [sum](const double& value) { return value / sum; });
342 return result;
343 }
344 std::vector<std::string> Network::show() const
345 {
346 std::vector<std::string> result;
347 // Draw the network
348 for (auto& node : nodes) {
349 std::string line = node.first + " -> ";
350 for (auto child : node.second->getChildren()) {
351 line += child->getName() + ", ";
352 }
353 result.push_back(line);
354 }
355 return result;
356 }
357 std::vector<std::string> Network::graph(const std::string& title) const
358 {
359 auto output = std::vector<std::string>();
360 auto prefix = "digraph BayesNet {\nlabel=<BayesNet ";
361 auto suffix = ">\nfontsize=30\nfontcolor=blue\nlabelloc=t\nlayout=circo\n";
362 std::string header = prefix + title + suffix;
363 output.push_back(header);
364 for (auto& node : nodes) {
365 auto result = node.second->graph(className);
366 output.insert(output.end(), result.begin(), result.end());
367 }
368 output.push_back("}\n");
369 return output;
370 }
371 std::vector<std::pair<std::string, std::string>> Network::getEdges() const
372 {
373 auto edges = std::vector<std::pair<std::string, std::string>>();
374 for (const auto& node : nodes) {
375 auto head = node.first;
376 for (const auto& child : node.second->getChildren()) {
377 auto tail = child->getName();
378 edges.push_back({ head, tail });
379 }
380 }
381 return edges;
382 }
383 int Network::getNumEdges() const
384 {
385 return getEdges().size();
386 }
387 std::vector<std::string> Network::topological_sort()
388 {
389 /* Check if al the fathers of every node are before the node */
390 auto result = features;
391 result.erase(remove(result.begin(), result.end(), className), result.end());
392 bool ending{ false };
393 while (!ending) {
394 ending = true;
395 for (auto feature : features) {
396 auto fathers = nodes[feature]->getParents();
397 for (const auto& father : fathers) {
398 auto fatherName = father->getName();
399 if (fatherName == className) {
400 continue;
401 }
402 // Check if father is placed before the actual feature
403 auto it = find(result.begin(), result.end(), fatherName);
404 if (it != result.end()) {
405 auto it2 = find(result.begin(), result.end(), feature);
406 if (it2 != result.end()) {
407 if (distance(it, it2) < 0) {
408 // if it is not, insert it before the feature
409 result.erase(remove(result.begin(), result.end(), fatherName), result.end());
410 result.insert(it2, fatherName);
411 ending = false;
412 }
413 }
414 }
415 }
416 }
417 }
418 return result;
419 }
420 std::string Network::dump_cpt() const
421 {
422 std::stringstream oss;
423 for (auto& node : nodes) {
424 oss << "* " << node.first << ": (" << node.second->getNumStates() << ") : " << node.second->getCPT().sizes() << std::endl;
425 oss << node.second->getCPT() << std::endl;
426 }
427 return oss.str();
428 }
429}
</html>