From 69ad660040017ace93a981a88ab8c20df55c6699 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Mon, 13 Nov 2023 13:59:06 +0100 Subject: [PATCH] Refactor version method in PyClassifier --- src/BayesNet/AODELd.cc | 1 - src/PyClassifiers/PyClassifier.cc | 9 ++++----- src/PyClassifiers/PyClassifier.h | 4 ++-- src/PyClassifiers/RandomForest.cc | 7 +++++-- src/PyClassifiers/RandomForest.h | 4 ++-- src/PyClassifiers/SVC.cc | 4 ---- src/PyClassifiers/SVC.h | 3 +-- 7 files changed, 14 insertions(+), 18 deletions(-) diff --git a/src/BayesNet/AODELd.cc b/src/BayesNet/AODELd.cc index fc899a9..776e37c 100644 --- a/src/BayesNet/AODELd.cc +++ b/src/BayesNet/AODELd.cc @@ -1,5 +1,4 @@ #include "AODELd.h" -#include "Models.h" namespace bayesnet { AODELd::AODELd() : Ensemble(), Proposal(dataset, features, className) {} diff --git a/src/PyClassifiers/PyClassifier.cc b/src/PyClassifiers/PyClassifier.cc index 779c64f..9e9f75a 100644 --- a/src/PyClassifiers/PyClassifier.cc +++ b/src/PyClassifiers/PyClassifier.cc @@ -2,7 +2,7 @@ namespace pywrap { namespace bp = boost::python; namespace np = boost::python::numpy; - PyClassifier::PyClassifier(const std::string& module, const std::string& className) : module(module), className(className), fitted(false) + PyClassifier::PyClassifier(const std::string& module, const std::string& className, bool sklearn) : module(module), className(className), sklearn(sklearn), fitted(false) { // This id allows to have more than one instance of the same module/class id = reinterpret_cast(this); @@ -29,12 +29,11 @@ namespace pywrap { } std::string PyClassifier::version() { + if (sklearn) { + return pyWrap->sklearnVersion(); + } return pyWrap->version(id); } - std::string PyClassifier::sklearnVersion() - { - return pyWrap->sklearnVersion(); - } std::string PyClassifier::callMethodString(const std::string& method) { return pyWrap->callMethodString(id, method); diff --git a/src/PyClassifiers/PyClassifier.h b/src/PyClassifiers/PyClassifier.h index 32f2b3d..ab98583 100644 --- a/src/PyClassifiers/PyClassifier.h +++ b/src/PyClassifiers/PyClassifier.h @@ -15,7 +15,7 @@ namespace pywrap { class PyClassifier : public bayesnet::BaseClassifier { public: - PyClassifier(const std::string& module, const std::string& className); + PyClassifier(const std::string& module, const std::string& className, const bool sklearn = false); virtual ~PyClassifier(); PyClassifier& fit(std::vector>& X, std::vector& y, const std::vector& features, const std::string& className, std::map>& states) override { return *this; }; // X is nxm tensor, y is nx1 tensor @@ -29,7 +29,6 @@ namespace pywrap { float score(torch::Tensor& X, torch::Tensor& y) override; void setHyperparameters(nlohmann::json& hyperparameters) override; std::string version(); - std::string sklearnVersion(); std::string callMethodString(const std::string& method); std::string getVersion() override { return this->version(); }; int getNumberOfNodes()const override { return 0; }; @@ -48,6 +47,7 @@ namespace pywrap { PyWrap* pyWrap; std::string module; std::string className; + bool sklearn; clfId_t id; bool fitted; }; diff --git a/src/PyClassifiers/RandomForest.cc b/src/PyClassifiers/RandomForest.cc index dd0be1f..3ba2424 100644 --- a/src/PyClassifiers/RandomForest.cc +++ b/src/PyClassifiers/RandomForest.cc @@ -1,8 +1,11 @@ #include "RandomForest.h" namespace pywrap { - std::string RandomForest::version() + void RandomForest::setHyperparameters(nlohmann::json& hyperparameters) { - return sklearnVersion(); + // Check if hyperparameters are valid + const std::vector validKeys = { "n_estimators", "n_jobs", "random_state" }; + checkHyperparameters(validKeys, hyperparameters); + this->hyperparameters = hyperparameters; } } /* namespace pywrap */ \ No newline at end of file diff --git a/src/PyClassifiers/RandomForest.h b/src/PyClassifiers/RandomForest.h index ad906c1..a6b2162 100644 --- a/src/PyClassifiers/RandomForest.h +++ b/src/PyClassifiers/RandomForest.h @@ -5,9 +5,9 @@ namespace pywrap { class RandomForest : public PyClassifier { public: - RandomForest() : PyClassifier("sklearn.ensemble", "RandomForestClassifier") {}; + RandomForest() : PyClassifier("sklearn.ensemble", "RandomForestClassifier", true) {}; ~RandomForest() = default; - std::string version(); + void setHyperparameters(nlohmann::json& hyperparameters) override; }; } /* namespace pywrap */ #endif /* RANDOMFOREST_H */ \ No newline at end of file diff --git a/src/PyClassifiers/SVC.cc b/src/PyClassifiers/SVC.cc index 5734700..2ed9c3a 100644 --- a/src/PyClassifiers/SVC.cc +++ b/src/PyClassifiers/SVC.cc @@ -1,10 +1,6 @@ #include "SVC.h" namespace pywrap { - std::string SVC::version() - { - return sklearnVersion(); - } void SVC::setHyperparameters(nlohmann::json& hyperparameters) { // Check if hyperparameters are valid diff --git a/src/PyClassifiers/SVC.h b/src/PyClassifiers/SVC.h index 1d6ac42..d62bbbc 100644 --- a/src/PyClassifiers/SVC.h +++ b/src/PyClassifiers/SVC.h @@ -5,9 +5,8 @@ namespace pywrap { class SVC : public PyClassifier { public: - SVC() : PyClassifier("sklearn.svm", "SVC") {}; + SVC() : PyClassifier("sklearn.svm", "SVC", true) {}; ~SVC() = default; - std::string version(); void setHyperparameters(nlohmann::json& hyperparameters) override; };