9.6 KiB
9.6 KiB
<html lang="en">
<head>
</head>
</html>
LCOV - code coverage report | ||||||||||||||||||||||
![]() | ||||||||||||||||||||||
|
||||||||||||||||||||||
![]() |
Line data Source code 1 : // *************************************************************** 2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez 3 : // SPDX-FileType: SOURCE 4 : // SPDX-License-Identifier: MIT 5 : // *************************************************************** 6 : 7 : #ifndef ENSEMBLE_H 8 : #define ENSEMBLE_H 9 : #include <torch/torch.h> 10 : #include "bayesnet/utils/BayesMetrics.h" 11 : #include "bayesnet/utils/bayesnetUtils.h" 12 : #include "bayesnet/classifiers/Classifier.h" 13 : 14 : namespace bayesnet { 15 : class Ensemble : public Classifier { 16 : public: 17 : Ensemble(bool predict_voting = true); 18 56 : virtual ~Ensemble() = default; 19 : torch::Tensor predict(torch::Tensor& X) override; 20 : std::vector<int> predict(std::vector<std::vector<int>>& X) override; 21 : torch::Tensor predict_proba(torch::Tensor& X) override; 22 : std::vector<std::vector<double>> predict_proba(std::vector<std::vector<int>>& X) override; 23 : float score(torch::Tensor& X, torch::Tensor& y) override; 24 : float score(std::vector<std::vector<int>>& X, std::vector<int>& y) override; 25 : int getNumberOfNodes() const override; 26 : int getNumberOfEdges() const override; 27 : int getNumberOfStates() const override; 28 : std::vector<std::string> show() const override; 29 : std::vector<std::string> graph(const std::string& title) const override; 30 6 : std::vector<std::string> topological_order() override 31 : { 32 6 : return std::vector<std::string>(); 33 : } 34 4 : std::string dump_cpt() const override 35 : { 36 8 : return ""; 37 : } 38 : protected: 39 : torch::Tensor predict_average_voting(torch::Tensor& X); 40 : std::vector<std::vector<double>> predict_average_voting(std::vector<std::vector<int>>& X); 41 : torch::Tensor predict_average_proba(torch::Tensor& X); 42 : std::vector<std::vector<double>> predict_average_proba(std::vector<std::vector<int>>& X); 43 : torch::Tensor compute_arg_max(torch::Tensor& X); 44 : std::vector<int> compute_arg_max(std::vector<std::vector<double>>& X); 45 : torch::Tensor voting(torch::Tensor& votes); 46 : unsigned n_models; 47 : std::vector<std::unique_ptr<Classifier>> models; 48 : std::vector<double> significanceModels; 49 : void trainModel(const torch::Tensor& weights) override; 50 : bool predict_voting; 51 : }; 52 : } 53 : #endif |
![]() |
Generated by: LCOV version 2.0-1 |
</html>