diff --git a/src/PyClassifiers/PyClassifier.cc b/src/PyClassifiers/PyClassifier.cc index 0a114e1..aeb3798 100644 --- a/src/PyClassifiers/PyClassifier.cc +++ b/src/PyClassifiers/PyClassifier.cc @@ -38,6 +38,10 @@ namespace pywrap { { return pyWrap->callMethodString(id, method); } + int PyClassifier::callMethodSumOfItems(const std::string& method) const + { + return pyWrap->callMethodSumOfItems(id, method); + } int PyClassifier::callMethodInt(const std::string& method) const { return pyWrap->callMethodInt(id, method); diff --git a/src/PyClassifiers/PyClassifier.h b/src/PyClassifiers/PyClassifier.h index 7cf3fd2..7260d2e 100644 --- a/src/PyClassifiers/PyClassifier.h +++ b/src/PyClassifiers/PyClassifier.h @@ -29,9 +29,9 @@ namespace pywrap { float score(torch::Tensor& X, torch::Tensor& y) override; std::string version(); std::string callMethodString(const std::string& method); + int callMethodSumOfItems(const std::string& method) const; int callMethodInt(const std::string& method) const; std::string getVersion() override { return this->version(); }; - // TODO: Implement these 3 methods int getNumberOfNodes() const override { return 0; }; int getNumberOfEdges() const override { return 0; }; int getNumberOfStates() const override { return 0; }; diff --git a/src/PyClassifiers/PyWrap.cc b/src/PyClassifiers/PyWrap.cc index 6167156..88a5a9c 100644 --- a/src/PyClassifiers/PyWrap.cc +++ b/src/PyClassifiers/PyWrap.cc @@ -145,6 +145,45 @@ namespace pywrap { { return callMethodString(id, "version"); } + int PyWrap::callMethodSumOfItems(const clfId_t id, const std::string& method) + { + // Call method on each estimator and sum the results (made for RandomForest) + PyObject* instance = getClass(id); + PyObject* estimators = PyObject_GetAttrString(instance, "estimators_"); + if (estimators == nullptr) { + errorAbort("Failed to get attribute: " + method); + } + int sumOfItems = 0; + Py_ssize_t len = PyList_Size(estimators); + for (Py_ssize_t i = 0; i < len; i++) { + PyObject* estimator = PyList_GetItem(estimators, i); + PyObject* result; + if (method == "node_count") { + PyObject* owner = PyObject_GetAttrString(estimator, "tree_"); + if (owner == nullptr) { + Py_XDECREF(estimators); + errorAbort("Failed to get attribute tree_ for: " + method); + } + result = PyObject_GetAttrString(owner, method.c_str()); + if (result == nullptr) { + Py_XDECREF(estimators); + Py_XDECREF(owner); + errorAbort("Failed to get attribute node_count: " + method); + } + Py_DECREF(owner); + } else { + result = PyObject_CallMethod(estimator, method.c_str(), nullptr); + if (result == nullptr) { + Py_XDECREF(estimators); + errorAbort("Failed to call method: " + method); + } + } + sumOfItems += PyLong_AsLong(result); + Py_DECREF(result); + } + Py_DECREF(estimators); + return sumOfItems; + } void PyWrap::setHyperparameters(const clfId_t id, const json& hyperparameters) { // Set hyperparameters as attributes of the class diff --git a/src/PyClassifiers/PyWrap.h b/src/PyClassifiers/PyWrap.h index 0c3be3b..d23b746 100644 --- a/src/PyClassifiers/PyWrap.h +++ b/src/PyClassifiers/PyWrap.h @@ -27,6 +27,7 @@ namespace pywrap { int callMethodInt(const clfId_t id, const std::string& method); std::string sklearnVersion(); std::string version(const clfId_t id); + int callMethodSumOfItems(const clfId_t id, const std::string& method); void setHyperparameters(const clfId_t id, const json& hyperparameters); void fit(const clfId_t id, CPyObject& X, CPyObject& y); PyObject* predict(const clfId_t id, CPyObject& X); diff --git a/src/PyClassifiers/RandomForest.cc b/src/PyClassifiers/RandomForest.cc index a4c3f9f..dfdb8ba 100644 --- a/src/PyClassifiers/RandomForest.cc +++ b/src/PyClassifiers/RandomForest.cc @@ -5,4 +5,16 @@ namespace pywrap { { validHyperparameters = { "n_estimators", "n_jobs", "random_state" }; } + int RandomForest::getNumberOfEdges() const + { + return callMethodSumOfItems("get_n_leaves"); + } + int RandomForest::getNumberOfStates() const + { + return callMethodSumOfItems("get_depth"); + } + int RandomForest::getNumberOfNodes() const + { + return callMethodSumOfItems("node_count"); + } } /* namespace pywrap */ \ No newline at end of file diff --git a/src/PyClassifiers/RandomForest.h b/src/PyClassifiers/RandomForest.h index e22af10..2e1d3d0 100644 --- a/src/PyClassifiers/RandomForest.h +++ b/src/PyClassifiers/RandomForest.h @@ -7,6 +7,9 @@ namespace pywrap { public: RandomForest(); ~RandomForest() = default; + int getNumberOfEdges() const override; + int getNumberOfStates() const override; + int getNumberOfNodes() const override; }; } /* namespace pywrap */ #endif /* RANDOMFOREST_H */ \ No newline at end of file