#ifndef PYCLASSIFER_H #define PYCLASSIFER_H #include "boost/python/detail/wrap_python.hpp" #include #include #include #include #include #include #include #include "PyWrap.h" #include "Classifier.h" #include "TypeId.h" namespace pywrap { class PyClassifier : public Classifier { public: PyClassifier(const std::string& module, const std::string& className); virtual ~PyClassifier(); PyClassifier& fit(torch::Tensor& X, torch::Tensor& y, const std::vector& features, const std::string& className, std::map>& states); PyClassifier& fit(torch::Tensor& X, torch::Tensor& y); torch::Tensor predict(torch::Tensor& X); double score(torch::Tensor& X, torch::Tensor& y); std::string version(); std::string callMethodString(const std::string& method); void setHyperparameters(const nlohmann::json& hyperparameters) override; protected: void checkHyperparameters(const std::vector& validKeys, const nlohmann::json& hyperparameters); nlohmann::json hyperparameters; private: PyWrap* pyWrap; std::string module; std::string className; clfId_t id; bool fitted; }; } /* namespace pywrap */ #endif /* PYCLASSIFER_H */