#ifndef CLASSIFIERS_H #define CLASSIFIERS_H #include #include "Network.h" #include "Metrics.hpp" using namespace std; using namespace torch; namespace bayesnet { class BaseClassifier { private: bool fitted; BaseClassifier& build(vector& features, string className, map>& states); protected: Network model; int m, n; // m: number of samples, n: number of features Tensor X; vector> Xv; Tensor y; vector yv; Tensor dataset; Metrics metrics; vector features; string className; map> states; void checkFitParameters(); virtual void train() = 0; public: BaseClassifier(Network model); virtual ~BaseClassifier() = default; BaseClassifier& fit(vector>& X, vector& y, vector& features, string className, map>& states); void addNodes(); Tensor predict(Tensor& X); vector predict(vector>& X); float score(Tensor& X, Tensor& y); float score(vector>& X, vector& y); vector show(); }; } #endif