Files
BayesNet/html/bayesnet/network/Network.h.gcov.html

13 KiB

<html lang="en"> <head> </head>
LCOV - code coverage report
Current view: top level - bayesnet/network - Network.h (source / functions) Coverage Total Hit
Test: coverage.info Lines: 100.0 % 1 1
Test Date: 2024-04-30 20:26:57 Functions: 100.0 % 1 1

            Line data    Source code
       1              : // ***************************************************************
       2              : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
       3              : // SPDX-FileType: SOURCE
       4              : // SPDX-License-Identifier: MIT
       5              : // ***************************************************************
       6              : 
       7              : #ifndef NETWORK_H
       8              : #define NETWORK_H
       9              : #include <map>
      10              : #include <vector>
      11              : #include "bayesnet/config.h"
      12              : #include "Node.h"
      13              : 
      14              : namespace bayesnet {
      15              :     class Network { // Declares a network that keeps a name-keyed map of owned Node objects and exposes fit/predict/score entry points
      16              :     public:
      17              :         Network();
      18              :         explicit Network(float); // NOTE(review): the float is presumably the maxThreads value - confirm in the implementation
      19              :         explicit Network(const Network&); // NOTE(review): 'explicit' on a copy constructor rejects copy-initialization such as `Network b = a;` - confirm this is intended
      20         1542 :         ~Network() = default;
      21              :         torch::Tensor& getSamples(); // Returns a non-const reference; presumably to the `samples` member - verify
      22              :         float getMaxThreads() const;
      23              :         void addNode(const std::string&); // NOTE(review): argument is presumably the node name used as key in `nodes` - confirm
      24              :         void addEdge(const std::string&, const std::string&); // NOTE(review): presumably (parent, child) node names - confirm the order in the implementation
      25              :         std::map<std::string, std::unique_ptr<Node>>& getNodes(); // Hands out a non-const reference to a map of unique_ptr<Node>
      26              :         std::vector<std::string> getFeatures() const; // Returned by value
      27              :         int getStates() const;
      28              :         std::vector<std::pair<std::string, std::string>> getEdges() const; // Each edge is a pair of strings, returned by value
      29              :         int getNumEdges() const;
      30              :         int getClassNumStates() const;
      31              :         std::string getClassName() const;
      32              :         /*
      33              :         Notice: Nodes have to be inserted in the same order as they are in the dataset, i.e., first node is first column and so on.
      34              :         */
      35              :         void fit(const std::vector<std::vector<int>>& input_data, const std::vector<int>& labels, const std::vector<double>& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states); // Overload taking std::vector data, labels and weights
      36              :         void fit(const torch::Tensor& X, const torch::Tensor& y, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states); // Overload taking separate X, y and weights tensors
      37              :         void fit(const torch::Tensor& samples, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states); // Overload taking a single samples tensor (no separate y argument) plus weights
      38              :         std::vector<int> predict(const std::vector<std::vector<int>>&); // Return mx1 std::vector of predictions
      39              :         torch::Tensor predict(const torch::Tensor&); // Return mx1 tensor of predictions
      40              :         torch::Tensor predict_tensor(const torch::Tensor& samples, const bool proba); // NOTE(review): `proba` presumably switches between probabilities and class predictions - confirm
      41              :         std::vector<std::vector<double>> predict_proba(const std::vector<std::vector<int>>&); // Return mxn std::vector of probabilities
      42              :         torch::Tensor predict_proba(const torch::Tensor&); // Return mxn tensor of probabilities
      43              :         double score(const std::vector<std::vector<int>>&, const std::vector<int>&); // NOTE(review): presumably scores predictions for the first argument against the second - confirm the metric
      44              :         std::vector<std::string> topological_sort(); // Not const-qualified, unlike show() and graph()
      45              :         std::vector<std::string> show() const;
      46              :         std::vector<std::string> graph(const std::string& title) const; // Returns a std::vector of std::strings representing the graph in graphviz format
      47              :         void initialize();
      48              :         std::string dump_cpt() const;
      49              :         inline std::string version() { return  { project_version.begin(), project_version.end() }; } // Builds a std::string from the begin()/end() range of project_version, which is not declared in this header (presumably bayesnet/config.h - verify)
      50              :     private:
      51              :         std::map<std::string, std::unique_ptr<Node>> nodes; // Keyed by std::string; the map owns each Node through unique_ptr
      52              :         bool fitted; // No in-class initializer here, unlike maxThreads
      53              :         float maxThreads = 0.95; // In-class default; 0.95 is a double literal converted to float
      54              :         int classNumStates; // No in-class initializer
      55              :         std::vector<std::string> features; // Including classname
      56              :         std::string className;
      57              :         double laplaceSmoothing; // No in-class initializer
      58              :         torch::Tensor samples; // n+1xm tensor used to fit the model
      59              :         bool isCyclic(const std::string&, std::unordered_set<std::string>&, std::unordered_set<std::string>&); // NOTE(review): std::unordered_set is used but <unordered_set> is not included directly by this header
      60              :         std::vector<double> predict_sample(const std::vector<int>&);
      61              :         std::vector<double> predict_sample(const torch::Tensor&);
      62              :         std::vector<double> exactInference(std::map<std::string, int>&); // Takes its map argument by non-const reference
      63              :         double computeFactor(std::map<std::string, int>&); // Takes its map argument by non-const reference
      64              :         void completeFit(const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights);
      65              :         void checkFitData(int n_features, int n_samples, int n_samples_y, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights);
      66              :         void setStates(const std::map<std::string, std::vector<int>>&);
      67              :     };
      68              : }
      69              : #endif
        

Generated by: LCOV version 2.0-1

</html>