diff --git a/src/PyClassifiers/ODTE.cc b/src/PyClassifiers/ODTE.cc index 4991bd9..11b1433 100644 --- a/src/PyClassifiers/ODTE.cc +++ b/src/PyClassifiers/ODTE.cc @@ -5,6 +5,18 @@ namespace pywrap { { validHyperparameters = { "n_jobs", "n_estimators", "random_state" }; } + int ODTE::getNumberOfNodes() const + { + return callMethodInt("get_nodes"); + } + int ODTE::getNumberOfEdges() const + { + return callMethodInt("get_leaves"); + } + int ODTE::getNumberOfStates() const + { + return callMethodInt("get_depth"); + } std::string ODTE::graph() { return callMethodString("graph"); diff --git a/src/PyClassifiers/ODTE.h b/src/PyClassifiers/ODTE.h index 9d44b24..0f968f3 100644 --- a/src/PyClassifiers/ODTE.h +++ b/src/PyClassifiers/ODTE.h @@ -8,6 +8,9 @@ namespace pywrap { public: ODTE(); ~ODTE() = default; + int getNumberOfNodes() const override; + int getNumberOfEdges() const override; + int getNumberOfStates() const override; std::string graph(); }; } /* namespace pywrap */ diff --git a/src/PyClassifiers/PyClassifier.cc b/src/PyClassifiers/PyClassifier.cc index 9406166..0a114e1 100644 --- a/src/PyClassifiers/PyClassifier.cc +++ b/src/PyClassifiers/PyClassifier.cc @@ -38,6 +38,10 @@ namespace pywrap { { return pyWrap->callMethodString(id, method); } + int PyClassifier::callMethodInt(const std::string& method) const + { + return pyWrap->callMethodInt(id, method); + } PyClassifier& PyClassifier::fit(torch::Tensor& X, torch::Tensor& y) { if (!fitted && hyperparameters.size() > 0) { diff --git a/src/PyClassifiers/PyClassifier.h b/src/PyClassifiers/PyClassifier.h index fad2b99..7cf3fd2 100644 --- a/src/PyClassifiers/PyClassifier.h +++ b/src/PyClassifiers/PyClassifier.h @@ -29,10 +29,11 @@ namespace pywrap { float score(torch::Tensor& X, torch::Tensor& y) override; std::string version(); std::string callMethodString(const std::string& method); + int callMethodInt(const std::string& method) const; std::string getVersion() override { return this->version(); }; // TODO: Implement these 3 methods - int getNumberOfNodes()const override { return 0; }; - int getNumberOfEdges()const override { return 0; }; + int getNumberOfNodes() const override { return 0; }; + int getNumberOfEdges() const override { return 0; }; int getNumberOfStates() const override { return 0; }; std::vector show() const override { return std::vector(); } std::vector graph(const std::string& title = "") const override { return std::vector(); } diff --git a/src/PyClassifiers/PyWrap.cc b/src/PyClassifiers/PyWrap.cc index 836bcd3..0250731 100644 --- a/src/PyClassifiers/PyWrap.cc +++ b/src/PyClassifiers/PyWrap.cc @@ -110,6 +110,21 @@ namespace pywrap { Py_XDECREF(result); return value; } + int PyWrap::callMethodInt(const clfId_t id, const std::string& method) + { + PyObject* instance = getClass(id); + PyObject* result; + try { + if (!(result = PyObject_CallMethod(instance, method.c_str(), NULL))) + errorAbort("Couldn't call method " + method); + } + catch (const std::exception& e) { + errorAbort(e.what()); + } + int value = PyLong_AsLong(result); + Py_XDECREF(result); + return value; + } std::string PyWrap::sklearnVersion() { return "1.0"; diff --git a/src/PyClassifiers/PyWrap.h b/src/PyClassifiers/PyWrap.h index 7f83c99..0c3be3b 100644 --- a/src/PyClassifiers/PyWrap.h +++ b/src/PyClassifiers/PyWrap.h @@ -24,6 +24,7 @@ namespace pywrap { void operator=(const PyWrap&) = delete; ~PyWrap() = default; std::string callMethodString(const clfId_t id, const std::string& method); + int callMethodInt(const clfId_t id, const std::string& method); std::string sklearnVersion(); std::string version(const clfId_t id); void setHyperparameters(const clfId_t id, const json& hyperparameters); diff --git a/src/PyClassifiers/STree.cc b/src/PyClassifiers/STree.cc index f97ed94..faff2ce 100644 --- a/src/PyClassifiers/STree.cc +++ b/src/PyClassifiers/STree.cc @@ -5,6 +5,18 @@ namespace pywrap { { validHyperparameters = { "C", "kernel", "max_iter", "max_depth", "random_state", "multiclass_strategy", "gamma", "max_features", "degree" }; }; + int STree::getNumberOfNodes() const + { + return callMethodInt("get_nodes"); + } + int STree::getNumberOfEdges() const + { + return callMethodInt("get_leaves"); + } + int STree::getNumberOfStates() const + { + return callMethodInt("get_depth"); + } std::string STree::graph() { return callMethodString("graph"); diff --git a/src/PyClassifiers/STree.h b/src/PyClassifiers/STree.h index 7b0b8e4..7862d3b 100644 --- a/src/PyClassifiers/STree.h +++ b/src/PyClassifiers/STree.h @@ -8,6 +8,9 @@ namespace pywrap { public: STree(); ~STree() = default; + int getNumberOfNodes() const override; + int getNumberOfEdges() const override; + int getNumberOfStates() const override; std::string graph(); }; } /* namespace pywrap */