Refactor version method in PyClassifier

This commit is contained in:
Ricardo Montañana Gómez 2023-11-13 13:59:06 +01:00
parent 431b3a3aa5
commit 69ad660040
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
7 changed files with 14 additions and 18 deletions

View File

@ -1,5 +1,4 @@
#include "AODELd.h" #include "AODELd.h"
#include "Models.h"
namespace bayesnet { namespace bayesnet {
AODELd::AODELd() : Ensemble(), Proposal(dataset, features, className) {} AODELd::AODELd() : Ensemble(), Proposal(dataset, features, className) {}

View File

@ -2,7 +2,7 @@
namespace pywrap { namespace pywrap {
namespace bp = boost::python; namespace bp = boost::python;
namespace np = boost::python::numpy; namespace np = boost::python::numpy;
PyClassifier::PyClassifier(const std::string& module, const std::string& className) : module(module), className(className), fitted(false) PyClassifier::PyClassifier(const std::string& module, const std::string& className, bool sklearn) : module(module), className(className), sklearn(sklearn), fitted(false)
{ {
// This id allows to have more than one instance of the same module/class // This id allows to have more than one instance of the same module/class
id = reinterpret_cast<clfId_t>(this); id = reinterpret_cast<clfId_t>(this);
@ -29,12 +29,11 @@ namespace pywrap {
} }
std::string PyClassifier::version() std::string PyClassifier::version()
{ {
if (sklearn) {
return pyWrap->sklearnVersion();
}
return pyWrap->version(id); return pyWrap->version(id);
} }
std::string PyClassifier::sklearnVersion()
{
return pyWrap->sklearnVersion();
}
std::string PyClassifier::callMethodString(const std::string& method) std::string PyClassifier::callMethodString(const std::string& method)
{ {
return pyWrap->callMethodString(id, method); return pyWrap->callMethodString(id, method);

View File

@ -15,7 +15,7 @@
namespace pywrap { namespace pywrap {
class PyClassifier : public bayesnet::BaseClassifier { class PyClassifier : public bayesnet::BaseClassifier {
public: public:
PyClassifier(const std::string& module, const std::string& className); PyClassifier(const std::string& module, const std::string& className, const bool sklearn = false);
virtual ~PyClassifier(); virtual ~PyClassifier();
PyClassifier& fit(std::vector<std::vector<int>>& X, std::vector<int>& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) override { return *this; }; PyClassifier& fit(std::vector<std::vector<int>>& X, std::vector<int>& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) override { return *this; };
// X is nxm tensor, y is nx1 tensor // X is nxm tensor, y is nx1 tensor
@ -29,7 +29,6 @@ namespace pywrap {
float score(torch::Tensor& X, torch::Tensor& y) override; float score(torch::Tensor& X, torch::Tensor& y) override;
void setHyperparameters(nlohmann::json& hyperparameters) override; void setHyperparameters(nlohmann::json& hyperparameters) override;
std::string version(); std::string version();
std::string sklearnVersion();
std::string callMethodString(const std::string& method); std::string callMethodString(const std::string& method);
std::string getVersion() override { return this->version(); }; std::string getVersion() override { return this->version(); };
int getNumberOfNodes()const override { return 0; }; int getNumberOfNodes()const override { return 0; };
@ -48,6 +47,7 @@ namespace pywrap {
PyWrap* pyWrap; PyWrap* pyWrap;
std::string module; std::string module;
std::string className; std::string className;
bool sklearn;
clfId_t id; clfId_t id;
bool fitted; bool fitted;
}; };

View File

@ -1,8 +1,11 @@
#include "RandomForest.h" #include "RandomForest.h"
namespace pywrap { namespace pywrap {
std::string RandomForest::version() void RandomForest::setHyperparameters(nlohmann::json& hyperparameters)
{ {
return sklearnVersion(); // Check if hyperparameters are valid
const std::vector<std::string> validKeys = { "n_estimators", "n_jobs", "random_state" };
checkHyperparameters(validKeys, hyperparameters);
this->hyperparameters = hyperparameters;
} }
} /* namespace pywrap */ } /* namespace pywrap */

View File

@ -5,9 +5,9 @@
namespace pywrap { namespace pywrap {
class RandomForest : public PyClassifier { class RandomForest : public PyClassifier {
public: public:
RandomForest() : PyClassifier("sklearn.ensemble", "RandomForestClassifier") {}; RandomForest() : PyClassifier("sklearn.ensemble", "RandomForestClassifier", true) {};
~RandomForest() = default; ~RandomForest() = default;
std::string version(); void setHyperparameters(nlohmann::json& hyperparameters) override;
}; };
} /* namespace pywrap */ } /* namespace pywrap */
#endif /* RANDOMFOREST_H */ #endif /* RANDOMFOREST_H */

View File

@ -1,10 +1,6 @@
#include "SVC.h" #include "SVC.h"
namespace pywrap { namespace pywrap {
std::string SVC::version()
{
return sklearnVersion();
}
void SVC::setHyperparameters(nlohmann::json& hyperparameters) void SVC::setHyperparameters(nlohmann::json& hyperparameters)
{ {
// Check if hyperparameters are valid // Check if hyperparameters are valid

View File

@ -5,9 +5,8 @@
namespace pywrap { namespace pywrap {
class SVC : public PyClassifier { class SVC : public PyClassifier {
public: public:
SVC() : PyClassifier("sklearn.svm", "SVC") {}; SVC() : PyClassifier("sklearn.svm", "SVC", true) {};
~SVC() = default; ~SVC() = default;
std::string version();
void setHyperparameters(nlohmann::json& hyperparameters) override; void setHyperparameters(nlohmann::json& hyperparameters) override;
}; };