#ifndef ENSEMBLE_H #define ENSEMBLE_H #include #include "BaseClassifier.h" #include "Metrics.hpp" #include "utils.h" using namespace std; using namespace torch; namespace bayesnet { class Ensemble { private: bool fitted; long n_models; Ensemble& build(vector& features, string className, map>& states); protected: vector> models; int m, n; // m: number of samples, n: number of features Tensor X; vector> Xv; Tensor y; vector yv; Tensor dataset; Metrics metrics; vector features; string className; map> states; void virtual train() = 0; vector voting(Tensor& y_pred); public: Ensemble(); virtual ~Ensemble() = default; Ensemble& fit(vector>& X, vector& y, vector& features, string className, map>& states); Tensor predict(Tensor& X); vector predict(vector>& X); float score(Tensor& X, Tensor& y); float score(vector>& X, vector& y); vector show(); vector graph(string title); }; } #endif