Implement Random Forest nodes/leaves/depth
This commit is contained in:
parent
4addaefb47
commit
d06bf187b2
@ -38,6 +38,10 @@ namespace pywrap {
|
||||
{
|
||||
return pyWrap->callMethodString(id, method);
|
||||
}
|
||||
int PyClassifier::callMethodSumOfItems(const std::string& method) const
|
||||
{
|
||||
return pyWrap->callMethodSumOfItems(id, method);
|
||||
}
|
||||
int PyClassifier::callMethodInt(const std::string& method) const
|
||||
{
|
||||
return pyWrap->callMethodInt(id, method);
|
||||
|
@ -29,9 +29,9 @@ namespace pywrap {
|
||||
float score(torch::Tensor& X, torch::Tensor& y) override;
|
||||
std::string version();
|
||||
std::string callMethodString(const std::string& method);
|
||||
int callMethodSumOfItems(const std::string& method) const;
|
||||
int callMethodInt(const std::string& method) const;
|
||||
std::string getVersion() override { return this->version(); };
|
||||
// TODO: Implement these 3 methods
|
||||
int getNumberOfNodes() const override { return 0; };
|
||||
int getNumberOfEdges() const override { return 0; };
|
||||
int getNumberOfStates() const override { return 0; };
|
||||
|
@ -145,6 +145,45 @@ namespace pywrap {
|
||||
{
|
||||
return callMethodString(id, "version");
|
||||
}
|
||||
int PyWrap::callMethodSumOfItems(const clfId_t id, const std::string& method)
|
||||
{
|
||||
// Call method on each estimator and sum the results (made for RandomForest)
|
||||
PyObject* instance = getClass(id);
|
||||
PyObject* estimators = PyObject_GetAttrString(instance, "estimators_");
|
||||
if (estimators == nullptr) {
|
||||
errorAbort("Failed to get attribute: " + method);
|
||||
}
|
||||
int sumOfItems = 0;
|
||||
Py_ssize_t len = PyList_Size(estimators);
|
||||
for (Py_ssize_t i = 0; i < len; i++) {
|
||||
PyObject* estimator = PyList_GetItem(estimators, i);
|
||||
PyObject* result;
|
||||
if (method == "node_count") {
|
||||
PyObject* owner = PyObject_GetAttrString(estimator, "tree_");
|
||||
if (owner == nullptr) {
|
||||
Py_XDECREF(estimators);
|
||||
errorAbort("Failed to get attribute tree_ for: " + method);
|
||||
}
|
||||
result = PyObject_GetAttrString(owner, method.c_str());
|
||||
if (result == nullptr) {
|
||||
Py_XDECREF(estimators);
|
||||
Py_XDECREF(owner);
|
||||
errorAbort("Failed to get attribute node_count: " + method);
|
||||
}
|
||||
Py_DECREF(owner);
|
||||
} else {
|
||||
result = PyObject_CallMethod(estimator, method.c_str(), nullptr);
|
||||
if (result == nullptr) {
|
||||
Py_XDECREF(estimators);
|
||||
errorAbort("Failed to call method: " + method);
|
||||
}
|
||||
}
|
||||
sumOfItems += PyLong_AsLong(result);
|
||||
Py_DECREF(result);
|
||||
}
|
||||
Py_DECREF(estimators);
|
||||
return sumOfItems;
|
||||
}
|
||||
void PyWrap::setHyperparameters(const clfId_t id, const json& hyperparameters)
|
||||
{
|
||||
// Set hyperparameters as attributes of the class
|
||||
|
@ -27,6 +27,7 @@ namespace pywrap {
|
||||
int callMethodInt(const clfId_t id, const std::string& method);
|
||||
std::string sklearnVersion();
|
||||
std::string version(const clfId_t id);
|
||||
int callMethodSumOfItems(const clfId_t id, const std::string& method);
|
||||
void setHyperparameters(const clfId_t id, const json& hyperparameters);
|
||||
void fit(const clfId_t id, CPyObject& X, CPyObject& y);
|
||||
PyObject* predict(const clfId_t id, CPyObject& X);
|
||||
|
@ -5,4 +5,16 @@ namespace pywrap {
|
||||
{
|
||||
validHyperparameters = { "n_estimators", "n_jobs", "random_state" };
|
||||
}
|
||||
int RandomForest::getNumberOfEdges() const
|
||||
{
|
||||
return callMethodSumOfItems("get_n_leaves");
|
||||
}
|
||||
int RandomForest::getNumberOfStates() const
|
||||
{
|
||||
return callMethodSumOfItems("get_depth");
|
||||
}
|
||||
int RandomForest::getNumberOfNodes() const
|
||||
{
|
||||
return callMethodSumOfItems("node_count");
|
||||
}
|
||||
} /* namespace pywrap */
|
@ -7,6 +7,9 @@ namespace pywrap {
|
||||
public:
|
||||
RandomForest();
|
||||
~RandomForest() = default;
|
||||
int getNumberOfEdges() const override;
|
||||
int getNumberOfStates() const override;
|
||||
int getNumberOfNodes() const override;
|
||||
};
|
||||
} /* namespace pywrap */
|
||||
#endif /* RANDOMFOREST_H */
|
Loading…
Reference in New Issue
Block a user