13 KiB
13 KiB
<html lang="en">
<head>
</head>
</html>
LCOV - code coverage report | ||||||||||||||||||||||
![]() | ||||||||||||||||||||||
|
||||||||||||||||||||||
![]() |
Line data Source code 1 : // *************************************************************** 2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez 3 : // SPDX-FileType: SOURCE 4 : // SPDX-License-Identifier: MIT 5 : // *************************************************************** 6 : 7 : #ifndef NETWORK_H 8 : #define NETWORK_H 9 : #include <map> 10 : #include <vector> 11 : #include "bayesnet/config.h" 12 : #include "Node.h" 13 : 14 : namespace bayesnet { 15 : class Network { 16 : public: 17 : Network(); 18 : explicit Network(float); 19 : explicit Network(const Network&); 20 1542 : ~Network() = default; 21 : torch::Tensor& getSamples(); 22 : float getMaxThreads() const; 23 : void addNode(const std::string&); 24 : void addEdge(const std::string&, const std::string&); 25 : std::map<std::string, std::unique_ptr<Node>>& getNodes(); 26 : std::vector<std::string> getFeatures() const; 27 : int getStates() const; 28 : std::vector<std::pair<std::string, std::string>> getEdges() const; 29 : int getNumEdges() const; 30 : int getClassNumStates() const; 31 : std::string getClassName() const; 32 : /* 33 : Notice: Nodes have to be inserted in the same order as they are in the dataset, i.e., first node is first column and so on. 34 : */ 35 : void fit(const std::vector<std::vector<int>>& input_data, const std::vector<int>& labels, const std::vector<double>& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states); 36 : void fit(const torch::Tensor& X, const torch::Tensor& y, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states); 37 : void fit(const torch::Tensor& samples, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states); 38 : std::vector<int> predict(const std::vector<std::vector<int>>&); // Return mx1 std::vector of predictions 39 : torch::Tensor predict(const torch::Tensor&); // Return mx1 tensor of predictions 40 : torch::Tensor predict_tensor(const torch::Tensor& samples, const bool proba); 41 : std::vector<std::vector<double>> predict_proba(const std::vector<std::vector<int>>&); // Return mxn std::vector of probabilities 42 : torch::Tensor predict_proba(const torch::Tensor&); // Return mxn tensor of probabilities 43 : double score(const std::vector<std::vector<int>>&, const std::vector<int>&); 44 : std::vector<std::string> topological_sort(); 45 : std::vector<std::string> show() const; 46 : std::vector<std::string> graph(const std::string& title) const; // Returns a std::vector of std::strings representing the graph in graphviz format 47 : void initialize(); 48 : std::string dump_cpt() const; 49 : inline std::string version() { return { project_version.begin(), project_version.end() }; } 50 : private: 51 : std::map<std::string, std::unique_ptr<Node>> nodes; 52 : bool fitted; 53 : float maxThreads = 0.95; 54 : int classNumStates; 55 : std::vector<std::string> features; // Including classname 56 : std::string className; 57 : double laplaceSmoothing; 58 : torch::Tensor samples; // n+1xm tensor used to fit the model 59 : bool isCyclic(const std::string&, std::unordered_set<std::string>&, std::unordered_set<std::string>&); 60 : std::vector<double> predict_sample(const std::vector<int>&); 61 : std::vector<double> predict_sample(const torch::Tensor&); 62 : std::vector<double> exactInference(std::map<std::string, int>&); 63 : double computeFactor(std::map<std::string, int>&); 64 : void completeFit(const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights); 65 : void checkFitData(int n_features, int n_samples, int n_samples_y, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights); 66 : void setStates(const std::map<std::string, std::vector<int>>&); 67 : }; 68 : } 69 : #endif |
![]() |
Generated by: LCOV version 2.0-1 |
</html>