Refactor version method in PyClassifier
This commit is contained in:
parent
431b3a3aa5
commit
69ad660040
@ -1,5 +1,4 @@
|
||||
#include "AODELd.h"
|
||||
#include "Models.h"
|
||||
|
||||
namespace bayesnet {
|
||||
AODELd::AODELd() : Ensemble(), Proposal(dataset, features, className) {}
|
||||
|
@ -2,7 +2,7 @@
|
||||
namespace pywrap {
|
||||
namespace bp = boost::python;
|
||||
namespace np = boost::python::numpy;
|
||||
PyClassifier::PyClassifier(const std::string& module, const std::string& className) : module(module), className(className), fitted(false)
|
||||
PyClassifier::PyClassifier(const std::string& module, const std::string& className, bool sklearn) : module(module), className(className), sklearn(sklearn), fitted(false)
|
||||
{
|
||||
// This id allows to have more than one instance of the same module/class
|
||||
id = reinterpret_cast<clfId_t>(this);
|
||||
@ -29,12 +29,11 @@ namespace pywrap {
|
||||
}
|
||||
std::string PyClassifier::version()
|
||||
{
|
||||
if (sklearn) {
|
||||
return pyWrap->sklearnVersion();
|
||||
}
|
||||
return pyWrap->version(id);
|
||||
}
|
||||
std::string PyClassifier::sklearnVersion()
|
||||
{
|
||||
return pyWrap->sklearnVersion();
|
||||
}
|
||||
std::string PyClassifier::callMethodString(const std::string& method)
|
||||
{
|
||||
return pyWrap->callMethodString(id, method);
|
||||
|
@ -15,7 +15,7 @@
|
||||
namespace pywrap {
|
||||
class PyClassifier : public bayesnet::BaseClassifier {
|
||||
public:
|
||||
PyClassifier(const std::string& module, const std::string& className);
|
||||
PyClassifier(const std::string& module, const std::string& className, const bool sklearn = false);
|
||||
virtual ~PyClassifier();
|
||||
PyClassifier& fit(std::vector<std::vector<int>>& X, std::vector<int>& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) override { return *this; };
|
||||
// X is nxm tensor, y is nx1 tensor
|
||||
@ -29,7 +29,6 @@ namespace pywrap {
|
||||
float score(torch::Tensor& X, torch::Tensor& y) override;
|
||||
void setHyperparameters(nlohmann::json& hyperparameters) override;
|
||||
std::string version();
|
||||
std::string sklearnVersion();
|
||||
std::string callMethodString(const std::string& method);
|
||||
std::string getVersion() override { return this->version(); };
|
||||
int getNumberOfNodes()const override { return 0; };
|
||||
@ -48,6 +47,7 @@ namespace pywrap {
|
||||
PyWrap* pyWrap;
|
||||
std::string module;
|
||||
std::string className;
|
||||
bool sklearn;
|
||||
clfId_t id;
|
||||
bool fitted;
|
||||
};
|
||||
|
@ -1,8 +1,11 @@
|
||||
#include "RandomForest.h"
|
||||
|
||||
namespace pywrap {
|
||||
std::string RandomForest::version()
|
||||
void RandomForest::setHyperparameters(nlohmann::json& hyperparameters)
|
||||
{
|
||||
return sklearnVersion();
|
||||
// Check if hyperparameters are valid
|
||||
const std::vector<std::string> validKeys = { "n_estimators", "n_jobs", "random_state" };
|
||||
checkHyperparameters(validKeys, hyperparameters);
|
||||
this->hyperparameters = hyperparameters;
|
||||
}
|
||||
} /* namespace pywrap */
|
@ -5,9 +5,9 @@
|
||||
namespace pywrap {
|
||||
class RandomForest : public PyClassifier {
|
||||
public:
|
||||
RandomForest() : PyClassifier("sklearn.ensemble", "RandomForestClassifier") {};
|
||||
RandomForest() : PyClassifier("sklearn.ensemble", "RandomForestClassifier", true) {};
|
||||
~RandomForest() = default;
|
||||
std::string version();
|
||||
void setHyperparameters(nlohmann::json& hyperparameters) override;
|
||||
};
|
||||
} /* namespace pywrap */
|
||||
#endif /* RANDOMFOREST_H */
|
@ -1,10 +1,6 @@
|
||||
#include "SVC.h"
|
||||
|
||||
namespace pywrap {
|
||||
std::string SVC::version()
|
||||
{
|
||||
return sklearnVersion();
|
||||
}
|
||||
void SVC::setHyperparameters(nlohmann::json& hyperparameters)
|
||||
{
|
||||
// Check if hyperparameters are valid
|
||||
|
@ -5,9 +5,8 @@
|
||||
namespace pywrap {
|
||||
class SVC : public PyClassifier {
|
||||
public:
|
||||
SVC() : PyClassifier("sklearn.svm", "SVC") {};
|
||||
SVC() : PyClassifier("sklearn.svm", "SVC", true) {};
|
||||
~SVC() = default;
|
||||
std::string version();
|
||||
void setHyperparameters(nlohmann::json& hyperparameters) override;
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user