// *************************************************************** // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez // SPDX-FileType: SOURCE // SPDX-License-Identifier: MIT // *************************************************************** #ifndef ENSEMBLE_H #define ENSEMBLE_H #include #include "bayesnet/utils/BayesMetrics.h" #include "bayesnet/utils/bayesnetUtils.h" #include "bayesnet/classifiers/Classifier.h" namespace bayesnet { class Ensemble : public Classifier { public: Ensemble(bool predict_voting = true); virtual ~Ensemble() = default; torch::Tensor predict(torch::Tensor& X) override; std::vector predict(std::vector>& X) override; torch::Tensor predict_proba(torch::Tensor& X) override; std::vector> predict_proba(std::vector>& X) override; float score(torch::Tensor& X, torch::Tensor& y) override; float score(std::vector>& X, std::vector& y) override; int getNumberOfNodes() const override; int getNumberOfEdges() const override; int getNumberOfStates() const override; std::vector show() const override; std::vector graph(const std::string& title) const override; std::vector topological_order() override { return std::vector(); } std::string dump_cpt() const override { return ""; } protected: torch::Tensor predict_average_voting(torch::Tensor& X); std::vector> predict_average_voting(std::vector>& X); torch::Tensor predict_average_proba(torch::Tensor& X); std::vector> predict_average_proba(std::vector>& X); torch::Tensor compute_arg_max(torch::Tensor& X); std::vector compute_arg_max(std::vector>& X); torch::Tensor voting(torch::Tensor& votes); unsigned n_models; std::vector> models; std::vector significanceModels; void trainModel(const torch::Tensor& weights) override; bool predict_voting; }; } #endif