Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
7 : #ifndef NETWORK_H
8 : #define NETWORK_H
#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "bayesnet/config.h"
#include "Node.h"
13 :
14 : namespace bayesnet {
15 : class Network {
16 : public:
17 : Network();
18 : explicit Network(float);
19 : explicit Network(const Network&);
20 4024 : ~Network() = default;
21 : torch::Tensor& getSamples();
22 : float getMaxThreads() const;
23 : void addNode(const std::string&);
24 : void addEdge(const std::string&, const std::string&);
25 : std::map<std::string, std::unique_ptr<Node>>& getNodes();
26 : std::vector<std::string> getFeatures() const;
27 : int getStates() const;
28 : std::vector<std::pair<std::string, std::string>> getEdges() const;
29 : int getNumEdges() const;
30 : int getClassNumStates() const;
31 : std::string getClassName() const;
32 : /*
33 : Notice: Nodes have to be inserted in the same order as they are in the dataset, i.e., first node is first column and so on.
34 : */
35 : void fit(const std::vector<std::vector<int>>& input_data, const std::vector<int>& labels, const std::vector<double>& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states);
36 : void fit(const torch::Tensor& X, const torch::Tensor& y, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states);
37 : void fit(const torch::Tensor& samples, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states);
38 : std::vector<int> predict(const std::vector<std::vector<int>>&); // Return mx1 std::vector of predictions
39 : torch::Tensor predict(const torch::Tensor&); // Return mx1 tensor of predictions
40 : torch::Tensor predict_tensor(const torch::Tensor& samples, const bool proba);
41 : std::vector<std::vector<double>> predict_proba(const std::vector<std::vector<int>>&); // Return mxn std::vector of probabilities
42 : torch::Tensor predict_proba(const torch::Tensor&); // Return mxn tensor of probabilities
43 : double score(const std::vector<std::vector<int>>&, const std::vector<int>&);
44 : std::vector<std::string> topological_sort();
45 : std::vector<std::string> show() const;
46 : std::vector<std::string> graph(const std::string& title) const; // Returns a std::vector of std::strings representing the graph in graphviz format
47 : void initialize();
48 : std::string dump_cpt() const;
49 : inline std::string version() { return { project_version.begin(), project_version.end() }; }
50 : private:
51 : std::map<std::string, std::unique_ptr<Node>> nodes;
52 : bool fitted;
53 : float maxThreads = 0.95;
54 : int classNumStates;
55 : std::vector<std::string> features; // Including classname
56 : std::string className;
57 : double laplaceSmoothing;
58 : torch::Tensor samples; // n+1xm tensor used to fit the model
59 : bool isCyclic(const std::string&, std::unordered_set<std::string>&, std::unordered_set<std::string>&);
60 : std::vector<double> predict_sample(const std::vector<int>&);
61 : std::vector<double> predict_sample(const torch::Tensor&);
62 : std::vector<double> exactInference(std::map<std::string, int>&);
63 : double computeFactor(std::map<std::string, int>&);
64 : void completeFit(const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights);
65 : void checkFitData(int n_features, int n_samples, int n_samples_y, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights);
66 : void setStates(const std::map<std::string, std::vector<int>>&);
67 : };
68 : }
69 : #endif
|