Refactor version method in PyClassifier

This commit is contained in:
Ricardo Montañana Gómez 2023-11-13 13:59:06 +01:00
parent 431b3a3aa5
commit 69ad660040
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
7 changed files with 14 additions and 18 deletions

View File

@ -1,5 +1,4 @@
#include "AODELd.h"
#include "Models.h"
namespace bayesnet {
AODELd::AODELd() : Ensemble(), Proposal(dataset, features, className) {}

View File

@ -2,7 +2,7 @@
namespace pywrap {
namespace bp = boost::python;
namespace np = boost::python::numpy;
PyClassifier::PyClassifier(const std::string& module, const std::string& className) : module(module), className(className), fitted(false)
PyClassifier::PyClassifier(const std::string& module, const std::string& className, bool sklearn) : module(module), className(className), sklearn(sklearn), fitted(false)
{
// This id allows to have more than one instance of the same module/class
id = reinterpret_cast<clfId_t>(this);
@ -29,12 +29,11 @@ namespace pywrap {
}
std::string PyClassifier::version()
{
if (sklearn) {
return pyWrap->sklearnVersion();
}
return pyWrap->version(id);
}
std::string PyClassifier::sklearnVersion()
{
return pyWrap->sklearnVersion();
}
std::string PyClassifier::callMethodString(const std::string& method)
{
return pyWrap->callMethodString(id, method);

View File

@ -15,7 +15,7 @@
namespace pywrap {
class PyClassifier : public bayesnet::BaseClassifier {
public:
PyClassifier(const std::string& module, const std::string& className);
PyClassifier(const std::string& module, const std::string& className, const bool sklearn = false);
virtual ~PyClassifier();
PyClassifier& fit(std::vector<std::vector<int>>& X, std::vector<int>& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) override { return *this; };
// X is nxm tensor, y is nx1 tensor
@ -29,7 +29,6 @@ namespace pywrap {
float score(torch::Tensor& X, torch::Tensor& y) override;
void setHyperparameters(nlohmann::json& hyperparameters) override;
std::string version();
std::string sklearnVersion();
std::string callMethodString(const std::string& method);
std::string getVersion() override { return this->version(); };
int getNumberOfNodes()const override { return 0; };
@ -48,6 +47,7 @@ namespace pywrap {
PyWrap* pyWrap;
std::string module;
std::string className;
bool sklearn;
clfId_t id;
bool fitted;
};

View File

@ -1,8 +1,11 @@
#include "RandomForest.h"
namespace pywrap {
std::string RandomForest::version()
void RandomForest::setHyperparameters(nlohmann::json& hyperparameters)
{
return sklearnVersion();
// Check if hyperparameters are valid
const std::vector<std::string> validKeys = { "n_estimators", "n_jobs", "random_state" };
checkHyperparameters(validKeys, hyperparameters);
this->hyperparameters = hyperparameters;
}
} /* namespace pywrap */

View File

@ -5,9 +5,9 @@
namespace pywrap {
class RandomForest : public PyClassifier {
public:
RandomForest() : PyClassifier("sklearn.ensemble", "RandomForestClassifier") {};
RandomForest() : PyClassifier("sklearn.ensemble", "RandomForestClassifier", true) {};
~RandomForest() = default;
std::string version();
void setHyperparameters(nlohmann::json& hyperparameters) override;
};
} /* namespace pywrap */
#endif /* RANDOMFOREST_H */

View File

@ -1,10 +1,6 @@
#include "SVC.h"
namespace pywrap {
std::string SVC::version()
{
return sklearnVersion();
}
void SVC::setHyperparameters(nlohmann::json& hyperparameters)
{
// Check if hyperparameters are valid

View File

@ -5,9 +5,8 @@
namespace pywrap {
class SVC : public PyClassifier {
public:
SVC() : PyClassifier("sklearn.svm", "SVC") {};
SVC() : PyClassifier("sklearn.svm", "SVC", true) {};
~SVC() = default;
std::string version();
void setHyperparameters(nlohmann::json& hyperparameters) override;
};