Implement Random Forest nodes/leaves/depth

This commit is contained in:
Ricardo Montañana Gómez 2023-11-28 00:35:38 +01:00
parent 4addaefb47
commit d06bf187b2
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
6 changed files with 60 additions and 1 deletions

View File

@ -38,6 +38,10 @@ namespace pywrap {
{
return pyWrap->callMethodString(id, method);
}
int PyClassifier::callMethodSumOfItems(const std::string& method) const
{
return pyWrap->callMethodSumOfItems(id, method);
}
int PyClassifier::callMethodInt(const std::string& method) const
{
return pyWrap->callMethodInt(id, method);

View File

@ -29,9 +29,9 @@ namespace pywrap {
float score(torch::Tensor& X, torch::Tensor& y) override;
std::string version();
std::string callMethodString(const std::string& method);
int callMethodSumOfItems(const std::string& method) const;
int callMethodInt(const std::string& method) const;
std::string getVersion() override { return this->version(); };
// TODO: Implement these 3 methods
int getNumberOfNodes() const override { return 0; };
int getNumberOfEdges() const override { return 0; };
int getNumberOfStates() const override { return 0; };

View File

@ -145,6 +145,45 @@ namespace pywrap {
{
return callMethodString(id, "version");
}
int PyWrap::callMethodSumOfItems(const clfId_t id, const std::string& method)
{
// Call method on each estimator and sum the results (made for RandomForest)
PyObject* instance = getClass(id);
PyObject* estimators = PyObject_GetAttrString(instance, "estimators_");
if (estimators == nullptr) {
errorAbort("Failed to get attribute: " + method);
}
int sumOfItems = 0;
Py_ssize_t len = PyList_Size(estimators);
for (Py_ssize_t i = 0; i < len; i++) {
PyObject* estimator = PyList_GetItem(estimators, i);
PyObject* result;
if (method == "node_count") {
PyObject* owner = PyObject_GetAttrString(estimator, "tree_");
if (owner == nullptr) {
Py_XDECREF(estimators);
errorAbort("Failed to get attribute tree_ for: " + method);
}
result = PyObject_GetAttrString(owner, method.c_str());
if (result == nullptr) {
Py_XDECREF(estimators);
Py_XDECREF(owner);
errorAbort("Failed to get attribute node_count: " + method);
}
Py_DECREF(owner);
} else {
result = PyObject_CallMethod(estimator, method.c_str(), nullptr);
if (result == nullptr) {
Py_XDECREF(estimators);
errorAbort("Failed to call method: " + method);
}
}
sumOfItems += PyLong_AsLong(result);
Py_DECREF(result);
}
Py_DECREF(estimators);
return sumOfItems;
}
void PyWrap::setHyperparameters(const clfId_t id, const json& hyperparameters)
{
// Set hyperparameters as attributes of the class

View File

@ -27,6 +27,7 @@ namespace pywrap {
int callMethodInt(const clfId_t id, const std::string& method);
std::string sklearnVersion();
std::string version(const clfId_t id);
int callMethodSumOfItems(const clfId_t id, const std::string& method);
void setHyperparameters(const clfId_t id, const json& hyperparameters);
void fit(const clfId_t id, CPyObject& X, CPyObject& y);
PyObject* predict(const clfId_t id, CPyObject& X);

View File

@ -5,4 +5,16 @@ namespace pywrap {
{
validHyperparameters = { "n_estimators", "n_jobs", "random_state" };
}
int RandomForest::getNumberOfEdges() const
{
return callMethodSumOfItems("get_n_leaves");
}
int RandomForest::getNumberOfStates() const
{
return callMethodSumOfItems("get_depth");
}
int RandomForest::getNumberOfNodes() const
{
return callMethodSumOfItems("node_count");
}
} /* namespace pywrap */

View File

@ -7,6 +7,9 @@ namespace pywrap {
public:
RandomForest();
~RandomForest() = default;
int getNumberOfEdges() const override;
int getNumberOfStates() const override;
int getNumberOfNodes() const override;
};
} /* namespace pywrap */
#endif /* RANDOMFOREST_H */