Files
BayesNet/html/bayesnet/ensembles/Ensemble.h.gcov.html

9.5 KiB

<html lang="en"> <head> </head>
LCOV - code coverage report
Current view: top level - bayesnet/ensembles - Ensemble.h (source / functions) Coverage Total Hit
Test: coverage.info Lines: 100.0 % 5 5
Test Date: 2024-04-30 20:26:57 Functions: 100.0 % 3 3

            Line data    Source code
       1              : // ***************************************************************
       2              : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
       3              : // SPDX-FileType: SOURCE
       4              : // SPDX-License-Identifier: MIT
       5              : // ***************************************************************
       6              : 
       7              : #ifndef ENSEMBLE_H
       8              : #define ENSEMBLE_H
       9              : #include <torch/torch.h>
      10              : #include "bayesnet/utils/BayesMetrics.h"
      11              : #include "bayesnet/utils/bayesnetUtils.h"
      12              : #include "bayesnet/classifiers/Classifier.h"
      13              : 
      14              : namespace bayesnet {
      15              :     class Ensemble : public Classifier {
      16              :     public:
      17              :         Ensemble(bool predict_voting = true);
      18           56 :         virtual ~Ensemble() = default;
      19              :         torch::Tensor predict(torch::Tensor& X) override;
      20              :         std::vector<int> predict(std::vector<std::vector<int>>& X) override;
      21              :         torch::Tensor predict_proba(torch::Tensor& X) override;
      22              :         std::vector<std::vector<double>> predict_proba(std::vector<std::vector<int>>& X) override;
      23              :         float score(torch::Tensor& X, torch::Tensor& y) override;
      24              :         float score(std::vector<std::vector<int>>& X, std::vector<int>& y) override;
      25              :         int getNumberOfNodes() const override;
      26              :         int getNumberOfEdges() const override;
      27              :         int getNumberOfStates() const override;
      28              :         std::vector<std::string> show() const override;
      29              :         std::vector<std::string> graph(const std::string& title) const override;
      30            6 :         std::vector<std::string> topological_order()  override
      31              :         {
      32            6 :             return std::vector<std::string>();
      33              :         }
      34            4 :         std::string dump_cpt() const override
      35              :         {
      36            8 :             return "";
      37              :         }
      38              :     protected:
      39              :         torch::Tensor predict_average_voting(torch::Tensor& X);
      40              :         std::vector<std::vector<double>> predict_average_voting(std::vector<std::vector<int>>& X);
      41              :         torch::Tensor predict_average_proba(torch::Tensor& X);
      42              :         std::vector<std::vector<double>> predict_average_proba(std::vector<std::vector<int>>& X);
      43              :         torch::Tensor compute_arg_max(torch::Tensor& X);
      44              :         std::vector<int> compute_arg_max(std::vector<std::vector<double>>& X);
      45              :         torch::Tensor voting(torch::Tensor& votes);
      46              :         unsigned n_models;
      47              :         std::vector<std::unique_ptr<Classifier>> models;
      48              :         std::vector<double> significanceModels;
      49              :         void trainModel(const torch::Tensor& weights) override;
      50              :         bool predict_voting;
      51              :     };
      52              : }
      53              : #endif
        

Generated by: LCOV version 2.0-1

</html>