Implement Random Forest nodes/leaves/depth
This commit is contained in:
parent
4addaefb47
commit
d06bf187b2
@ -38,6 +38,10 @@ namespace pywrap {
|
|||||||
{
|
{
|
||||||
return pyWrap->callMethodString(id, method);
|
return pyWrap->callMethodString(id, method);
|
||||||
}
|
}
|
||||||
|
int PyClassifier::callMethodSumOfItems(const std::string& method) const
|
||||||
|
{
|
||||||
|
return pyWrap->callMethodSumOfItems(id, method);
|
||||||
|
}
|
||||||
int PyClassifier::callMethodInt(const std::string& method) const
|
int PyClassifier::callMethodInt(const std::string& method) const
|
||||||
{
|
{
|
||||||
return pyWrap->callMethodInt(id, method);
|
return pyWrap->callMethodInt(id, method);
|
||||||
|
@ -29,9 +29,9 @@ namespace pywrap {
|
|||||||
float score(torch::Tensor& X, torch::Tensor& y) override;
|
float score(torch::Tensor& X, torch::Tensor& y) override;
|
||||||
std::string version();
|
std::string version();
|
||||||
std::string callMethodString(const std::string& method);
|
std::string callMethodString(const std::string& method);
|
||||||
|
int callMethodSumOfItems(const std::string& method) const;
|
||||||
int callMethodInt(const std::string& method) const;
|
int callMethodInt(const std::string& method) const;
|
||||||
std::string getVersion() override { return this->version(); };
|
std::string getVersion() override { return this->version(); };
|
||||||
// TODO: Implement these 3 methods
|
|
||||||
int getNumberOfNodes() const override { return 0; };
|
int getNumberOfNodes() const override { return 0; };
|
||||||
int getNumberOfEdges() const override { return 0; };
|
int getNumberOfEdges() const override { return 0; };
|
||||||
int getNumberOfStates() const override { return 0; };
|
int getNumberOfStates() const override { return 0; };
|
||||||
|
@ -145,6 +145,45 @@ namespace pywrap {
|
|||||||
{
|
{
|
||||||
return callMethodString(id, "version");
|
return callMethodString(id, "version");
|
||||||
}
|
}
|
||||||
|
int PyWrap::callMethodSumOfItems(const clfId_t id, const std::string& method)
|
||||||
|
{
|
||||||
|
// Call method on each estimator and sum the results (made for RandomForest)
|
||||||
|
PyObject* instance = getClass(id);
|
||||||
|
PyObject* estimators = PyObject_GetAttrString(instance, "estimators_");
|
||||||
|
if (estimators == nullptr) {
|
||||||
|
errorAbort("Failed to get attribute: " + method);
|
||||||
|
}
|
||||||
|
int sumOfItems = 0;
|
||||||
|
Py_ssize_t len = PyList_Size(estimators);
|
||||||
|
for (Py_ssize_t i = 0; i < len; i++) {
|
||||||
|
PyObject* estimator = PyList_GetItem(estimators, i);
|
||||||
|
PyObject* result;
|
||||||
|
if (method == "node_count") {
|
||||||
|
PyObject* owner = PyObject_GetAttrString(estimator, "tree_");
|
||||||
|
if (owner == nullptr) {
|
||||||
|
Py_XDECREF(estimators);
|
||||||
|
errorAbort("Failed to get attribute tree_ for: " + method);
|
||||||
|
}
|
||||||
|
result = PyObject_GetAttrString(owner, method.c_str());
|
||||||
|
if (result == nullptr) {
|
||||||
|
Py_XDECREF(estimators);
|
||||||
|
Py_XDECREF(owner);
|
||||||
|
errorAbort("Failed to get attribute node_count: " + method);
|
||||||
|
}
|
||||||
|
Py_DECREF(owner);
|
||||||
|
} else {
|
||||||
|
result = PyObject_CallMethod(estimator, method.c_str(), nullptr);
|
||||||
|
if (result == nullptr) {
|
||||||
|
Py_XDECREF(estimators);
|
||||||
|
errorAbort("Failed to call method: " + method);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sumOfItems += PyLong_AsLong(result);
|
||||||
|
Py_DECREF(result);
|
||||||
|
}
|
||||||
|
Py_DECREF(estimators);
|
||||||
|
return sumOfItems;
|
||||||
|
}
|
||||||
void PyWrap::setHyperparameters(const clfId_t id, const json& hyperparameters)
|
void PyWrap::setHyperparameters(const clfId_t id, const json& hyperparameters)
|
||||||
{
|
{
|
||||||
// Set hyperparameters as attributes of the class
|
// Set hyperparameters as attributes of the class
|
||||||
|
@ -27,6 +27,7 @@ namespace pywrap {
|
|||||||
int callMethodInt(const clfId_t id, const std::string& method);
|
int callMethodInt(const clfId_t id, const std::string& method);
|
||||||
std::string sklearnVersion();
|
std::string sklearnVersion();
|
||||||
std::string version(const clfId_t id);
|
std::string version(const clfId_t id);
|
||||||
|
int callMethodSumOfItems(const clfId_t id, const std::string& method);
|
||||||
void setHyperparameters(const clfId_t id, const json& hyperparameters);
|
void setHyperparameters(const clfId_t id, const json& hyperparameters);
|
||||||
void fit(const clfId_t id, CPyObject& X, CPyObject& y);
|
void fit(const clfId_t id, CPyObject& X, CPyObject& y);
|
||||||
PyObject* predict(const clfId_t id, CPyObject& X);
|
PyObject* predict(const clfId_t id, CPyObject& X);
|
||||||
|
@ -5,4 +5,16 @@ namespace pywrap {
|
|||||||
{
|
{
|
||||||
validHyperparameters = { "n_estimators", "n_jobs", "random_state" };
|
validHyperparameters = { "n_estimators", "n_jobs", "random_state" };
|
||||||
}
|
}
|
||||||
|
int RandomForest::getNumberOfEdges() const
|
||||||
|
{
|
||||||
|
return callMethodSumOfItems("get_n_leaves");
|
||||||
|
}
|
||||||
|
int RandomForest::getNumberOfStates() const
|
||||||
|
{
|
||||||
|
return callMethodSumOfItems("get_depth");
|
||||||
|
}
|
||||||
|
int RandomForest::getNumberOfNodes() const
|
||||||
|
{
|
||||||
|
return callMethodSumOfItems("node_count");
|
||||||
|
}
|
||||||
} /* namespace pywrap */
|
} /* namespace pywrap */
|
@ -7,6 +7,9 @@ namespace pywrap {
|
|||||||
public:
|
public:
|
||||||
RandomForest();
|
RandomForest();
|
||||||
~RandomForest() = default;
|
~RandomForest() = default;
|
||||||
|
int getNumberOfEdges() const override;
|
||||||
|
int getNumberOfStates() const override;
|
||||||
|
int getNumberOfNodes() const override;
|
||||||
};
|
};
|
||||||
} /* namespace pywrap */
|
} /* namespace pywrap */
|
||||||
#endif /* RANDOMFOREST_H */
|
#endif /* RANDOMFOREST_H */
|
Loading…
Reference in New Issue
Block a user