// PyClassifier.h - declares pywrap::PyClassifier, a bayesnet::BaseClassifier
// implementation constructed from a module name and a class name.
// NOTE(review): given the Boost.Python and PyWrap includes plus the `sklearn`
// flag, this presumably wraps a Python (optionally scikit-learn) estimator;
// the implementation is not visible here - confirm.
// NOTE(review): in this copy the targets of several #include directives and
// all template arguments (e.g. "std::vector>", "std::map>") appear to have
// been stripped; restore them from the upstream header before building.
#ifndef PYCLASSIFIER_H
#define PYCLASSIFIER_H
#include
#include
#include
#include
#include "boost/python/detail/wrap_python.hpp"
#include
#include
#include
#include "bayesnet/classifiers/Classifier.h"
#include "PyWrap.h"
#include "TypeId.h"
namespace pywrap {
    // Classifier adapter that satisfies the bayesnet::BaseClassifier interface.
    // Only the tensor-based fit/predict/score entry points and
    // setHyperparameters are declared for out-of-line definition; the other
    // BaseClassifier overrides below are inline stubs that return fixed
    // zero/empty values.
    class PyClassifier : public bayesnet::BaseClassifier {
    public:
        // module / className: names used to locate the wrapped class
        // (definition not visible). sklearn: stored flag, defaults to false.
        PyClassifier(const std::string& module, const std::string& className, const bool sklearn = false);
        virtual ~PyClassifier();
        // Stub: vector-based fit is not supported; returns *this without training.
        PyClassifier& fit(std::vector>& X, std::vector& y, const std::vector& features, const std::string& className, std::map>& states) override { return *this; };
        // X is nxm tensor, y is nx1 tensor
        // Tensor-based fit; defined out of line.
        PyClassifier& fit(torch::Tensor& X, torch::Tensor& y, const std::vector& features, const std::string& className, std::map>& states) override;
        // Convenience overload without feature/state metadata (not an override).
        PyClassifier& fit(torch::Tensor& X, torch::Tensor& y);
        // Stub: dataset-tensor fit is not supported; returns *this without training.
        PyClassifier& fit(torch::Tensor& dataset, const std::vector& features, const std::string& className, std::map>& states) override { return *this; };
        // Stub: weighted dataset-tensor fit is not supported; weights are ignored.
        PyClassifier& fit(torch::Tensor& dataset, const std::vector& features, const std::string& className, std::map>& states, const torch::Tensor& weights) override { return *this; };
        // Tensor-based prediction; defined out of line.
        torch::Tensor predict(torch::Tensor& X) override;
        std::vector predict(std::vector>& X) override { return std::vector(); }; // Not implemented
        // Returns an empty 0x0 tensor.
        torch::Tensor predict_proba(torch::Tensor& X) override { return torch::zeros({ 0, 0 }); } // Not implemented
        std::vector> predict_proba(std::vector>& X) override { return std::vector>(); }; // Not implemented
        float score(std::vector>& X, std::vector& y) override { return 0.0; }; // Not implemented
        // Tensor-based scoring; defined out of line.
        float score(torch::Tensor& X, torch::Tensor& y) override;
        // Always 0 for this adapter.
        int getClassNumStates() const override { return 0; };
        std::string version();
        // NOTE(review): by their names these call a method (given by name) on the
        // wrapped object and convert its result to a string / a sum of items /
        // an int; definitions are not visible - confirm.
        std::string callMethodString(const std::string& method);
        int callMethodSumOfItems(const std::string& method) const;
        int callMethodInt(const std::string& method) const;
        // Forwards to version().
        std::string getVersion() override { return this->version(); };
        // Network-structure introspection does not apply to this adapter: the
        // following overrides return fixed zero/empty values and getStatus()
        // always reports bayesnet::NORMAL.
        int getNumberOfNodes() const override { return 0; };
        int getNumberOfEdges() const override { return 0; };
        int getNumberOfStates() const override { return 0; };
        std::vector show() const override { return std::vector(); }
        std::vector graph(const std::string& title = "") const override { return std::vector(); }
        bayesnet::status_t getStatus() const override { return bayesnet::NORMAL; };
        std::vector topological_order() override { return std::vector(); }
        // No-op.
        void dump_cpt() const override {};
        // Returns a copy of the accumulated notes.
        std::vector getNotes() const override { return notes; };
        // Defined out of line. NOTE(review): the parameter shadows the protected
        // member `hyperparameters`; the definition must qualify it (this->) if it
        // means to store the argument.
        void setHyperparameters(const nlohmann::json& hyperparameters) override;
    protected:
        nlohmann::json hyperparameters;
        // No-op: training goes through the tensor fit() overloads; weights are ignored.
        void trainModel(const torch::Tensor& weights) override {};
        std::vector notes;
    private:
        // NOTE(review): raw pointer whose ownership is not visible here. The class
        // declares a destructor but no copy/move operations; if the destructor
        // releases pyWrap, copying a PyClassifier would be unsafe - confirm and
        // consider deleting copy operations.
        PyWrap* pyWrap;
        std::string module;
        std::string className;
        bool sklearn;
        // NOTE(review): id and fitted (like pyWrap and sklearn) have no in-class
        // initializer; they rely on the constructor (not visible) to set them.
        clfId_t id;
        bool fitted;
    };
} /* namespace pywrap */
#endif /* PYCLASSIFIER_H */