Refactor version method in PyClassifier

This commit is contained in:
2023-11-13 13:59:06 +01:00
parent 431b3a3aa5
commit 69ad660040
7 changed files with 14 additions and 18 deletions

View File

@@ -15,7 +15,7 @@
namespace pywrap {
class PyClassifier : public bayesnet::BaseClassifier {
public:
PyClassifier(const std::string& module, const std::string& className);
PyClassifier(const std::string& module, const std::string& className, const bool sklearn = false);
virtual ~PyClassifier();
PyClassifier& fit(std::vector<std::vector<int>>& X, std::vector<int>& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) override { return *this; };
// X is nxm tensor, y is nx1 tensor
@@ -29,7 +29,6 @@ namespace pywrap {
float score(torch::Tensor& X, torch::Tensor& y) override;
void setHyperparameters(nlohmann::json& hyperparameters) override;
std::string version();
std::string sklearnVersion();
std::string callMethodString(const std::string& method);
std::string getVersion() override { return this->version(); };
int getNumberOfNodes()const override { return 0; };
@@ -48,6 +47,7 @@ namespace pywrap {
PyWrap* pyWrap;
std::string module;
std::string className;
bool sklearn;
clfId_t id;
bool fitted;
};