Files
BayesNet/docs/manual/_network_8h_source.html

19 KiB

<html xmlns="http://www.w3.org/1999/xhtml" lang="en-US"> <head> <script type="text/javascript" src="jquery.js"></script> <script type="text/javascript" src="dynsections.js"></script> <script type="text/javascript" src="clipboard.js"></script> <script type="text/javascript" src="navtreedata.js"></script> <script type="text/javascript" src="navtree.js"></script> <script type="text/javascript" src="resize.js"></script> <script type="text/javascript" src="cookie.js"></script> <script type="text/javascript" src="search/searchdata.js"></script> <script type="text/javascript" src="search/search.js"></script> </head>
BayesNet 1.0.5
Bayesian Network Classifiers using libtorch from scratch
<!-- Page bootstrap (Doxygen-style boilerplate): creates the search box, initializes code folding, builds the top menu and search, then initializes the navigation tree for this page and the resizable side panel. -->
<script type="text/javascript"> /* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699&dn=expat.txt MIT */ var searchBox = new SearchBox("searchBox", "search/",'.html'); /* @license-end */ </script> <script type="text/javascript"> /* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699&dn=expat.txt MIT */ $(function() { codefold.init(0); }); /* @license-end */ </script> <script type="text/javascript" src="menudata.js"></script> <script type="text/javascript" src="menu.js"></script> <script type="text/javascript"> /* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699&dn=expat.txt MIT */ $(function() { initMenu('',true,false,'search.php','Search',true); $(function() { init_search(); }); }); /* @license-end */ </script>
<script type="text/javascript"> /* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699&dn=expat.txt MIT */ $(function(){initNavTree('_network_8h_source.html',''); initResizable(true); }); /* @license-end */ </script>
Loading...
Searching...
No Matches
Network.h
1// ***************************************************************
2// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3// SPDX-FileType: SOURCE
4// SPDX-License-Identifier: MIT
5// ***************************************************************
6
7#ifndef NETWORK_H
8#define NETWORK_H
9#include <map>
10#include <vector>
11#include "bayesnet/config.h"
12#include "Node.h"
13
14namespace bayesnet {
15 class Network { // Bayesian network of named nodes and directed edges, with fit / predict / score API
16 public:
17 Network();
18 explicit Network(float);
19 explicit Network(const Network&); // NOTE(review): explicit copy ctor forbids copy-initialization and return by value; confirm this is intended
20 ~Network() = default;
21 torch::Tensor& getSamples();
22 float getMaxThreads() const;
23 void addNode(const std::string&);
24 void addEdge(const std::string&, const std::string&);
25 std::map<std::string, std::unique_ptr<Node>>& getNodes(); // mutable access to the owned node map
26 std::vector<std::string> getFeatures() const;
27 int getStates() const;
28 std::vector<std::pair<std::string, std::string>> getEdges() const;
29 int getNumEdges() const;
30 int getClassNumStates() const;
31 std::string getClassName() const;
32 /*
33 Notice: nodes must be inserted in the same order as the dataset columns, i.e., the first node is the first column, and so on.
34 */
35 void fit(const std::vector<std::vector<int>>& input_data, const std::vector<int>& labels, const std::vector<double>& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states);
36 void fit(const torch::Tensor& X, const torch::Tensor& y, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states);
37 void fit(const torch::Tensor& samples, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states);
38 std::vector<int> predict(const std::vector<std::vector<int>>&); // Return mx1 std::vector of predictions
39 torch::Tensor predict(const torch::Tensor&); // Return mx1 tensor of predictions
40 torch::Tensor predict_tensor(const torch::Tensor& samples, const bool proba);
41 std::vector<std::vector<double>> predict_proba(const std::vector<std::vector<int>>&); // Return mxn std::vector of probabilities
42 torch::Tensor predict_proba(const torch::Tensor&); // Return mxn tensor of probabilities
43 double score(const std::vector<std::vector<int>>&, const std::vector<int>&);
44 std::vector<std::string> topological_sort();
45 std::vector<std::string> show() const;
46 std::vector<std::string> graph(const std::string& title) const; // Returns a std::vector of std::strings representing the graph in graphviz format
47 void initialize();
48 std::string dump_cpt() const;
49 std::string version() const { return { project_version.begin(), project_version.end() }; } // copies project_version (presumably from bayesnet/config.h) into a std::string
50 private:
51 std::map<std::string, std::unique_ptr<Node>> nodes; // owns every Node, keyed by string; NOTE(review): <memory> is not included directly here
52 bool fitted = false; // in-class default so the flag is never read uninitialized
53 float maxThreads = 0.95F; // float literal: no double->float conversion
54 int classNumStates = 0; // in-class default; constructors/fit may overwrite
55 std::vector<std::string> features; // Including classname
56 std::string className;
57 double laplaceSmoothing = 0.0; // NOTE(review): presumably set by the constructors; default only guards against an uninitialized read
58 torch::Tensor samples; // n+1xm tensor used to fit the model
59 bool isCyclic(const std::string&, std::unordered_set<std::string>&, std::unordered_set<std::string>&); // NOTE(review): <unordered_set> is not included directly here
60 std::vector<double> predict_sample(const std::vector<int>&);
61 std::vector<double> predict_sample(const torch::Tensor&);
62 std::vector<double> exactInference(std::map<std::string, int>&);
63 double computeFactor(std::map<std::string, int>&);
64 void completeFit(const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights);
65 void checkFitData(int n_features, int n_samples, int n_samples_y, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights);
66 void setStates(const std::map<std::string, std::vector<int>>&);
67 };
68}
69#endif
</html>