// BayesNet/bayesnet/ensembles/Ensemble.h
// ***************************************************************
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
// SPDX-FileType: SOURCE
// SPDX-License-Identifier: MIT
// ***************************************************************
2023-07-14 16:23:24 +00:00
#ifndef ENSEMBLE_H
#define ENSEMBLE_H
#include <torch/torch.h>
2024-03-08 21:20:54 +00:00
#include "bayesnet/utils/BayesMetrics.h"
#include "bayesnet/utils/bayesnetUtils.h"
#include "bayesnet/classifiers/Classifier.h"
2023-07-14 16:23:24 +00:00
namespace bayesnet {
class Ensemble : public Classifier {
2023-07-14 16:23:24 +00:00
public:
Ensemble(bool predict_voting = true);
virtual ~Ensemble() = default;
2023-11-08 17:45:35 +00:00
torch::Tensor predict(torch::Tensor& X) override;
std::vector<int> predict(std::vector<std::vector<int>>& X) override;
torch::Tensor predict_proba(torch::Tensor& X) override;
std::vector<std::vector<double>> predict_proba(std::vector<std::vector<int>>& X) override;
2023-11-08 17:45:35 +00:00
float score(torch::Tensor& X, torch::Tensor& y) override;
float score(std::vector<std::vector<int>>& X, std::vector<int>& y) override;
2023-08-07 23:53:41 +00:00
int getNumberOfNodes() const override;
int getNumberOfEdges() const override;
int getNumberOfStates() const override;
2023-11-08 17:45:35 +00:00
std::vector<std::string> show() const override;
std::vector<std::string> graph(const std::string& title) const override;
std::vector<std::string> topological_order() override
2023-08-01 22:56:52 +00:00
{
2023-11-08 17:45:35 +00:00
return std::vector<std::string>();
2023-08-01 22:56:52 +00:00
}
2024-04-07 23:25:14 +00:00
std::string dump_cpt() const override
2023-08-03 18:22:33 +00:00
{
2024-04-07 23:25:14 +00:00
return "";
2023-08-03 18:22:33 +00:00
}
protected:
torch::Tensor predict_average_voting(torch::Tensor& X);
std::vector<std::vector<double>> predict_average_voting(std::vector<std::vector<int>>& X);
torch::Tensor predict_average_proba(torch::Tensor& X);
std::vector<std::vector<double>> predict_average_proba(std::vector<std::vector<int>>& X);
torch::Tensor compute_arg_max(torch::Tensor& X);
std::vector<int> compute_arg_max(std::vector<std::vector<double>>& X);
torch::Tensor voting(torch::Tensor& votes);
unsigned n_models;
std::vector<std::unique_ptr<Classifier>> models;
std::vector<double> significanceModels;
void trainModel(const torch::Tensor& weights) override;
bool predict_voting;
2023-07-14 16:23:24 +00:00
};
}
#endif