From 4addaefb474afc9bbcdec93fcee0a4b7c8377650 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Mon, 27 Nov 2023 22:34:34 +0100 Subject: [PATCH] Implement sklearn version in PyWrap --- src/PyClassifiers/PyWrap.cc | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/PyClassifiers/PyWrap.cc b/src/PyClassifiers/PyWrap.cc index 0250731..6167156 100644 --- a/src/PyClassifiers/PyWrap.cc +++ b/src/PyClassifiers/PyWrap.cc @@ -127,10 +127,19 @@ namespace pywrap { } std::string PyWrap::sklearnVersion() { - return "1.0"; - // CPyObject data = PyRun_SimpleString("import sklearn;return sklearn.__version__"); - // std::string result = PyUnicode_AsUTF8(data); - // return result; + PyObject* sklearnModule = PyImport_ImportModule("sklearn"); + if (sklearnModule == nullptr) { + errorAbort("Couldn't import sklearn"); + } + PyObject* versionAttr = PyObject_GetAttrString(sklearnModule, "__version__"); + if (versionAttr == nullptr || !PyUnicode_Check(versionAttr)) { + Py_XDECREF(sklearnModule); + errorAbort("Couldn't get sklearn version"); + } + std::string result = PyUnicode_AsUTF8(versionAttr); + Py_XDECREF(versionAttr); + Py_XDECREF(sklearnModule); + return result; } std::string PyWrap::version(const clfId_t id) {