From 8334f81276769fa17fa4745ca5e7d2663928ef1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Fri, 10 Nov 2023 10:42:01 +0100 Subject: [PATCH] Set attributes --- src/PyClassifier.cc | 1 + src/PyClassifier.h | 1 + src/PyWrap.cc | 103 ++++++++------------------------------------ src/main.cc | 9 ++-- 4 files changed, 26 insertions(+), 88 deletions(-) diff --git a/src/PyClassifier.cc b/src/PyClassifier.cc index 69ee5aa..f75387e 100644 --- a/src/PyClassifier.cc +++ b/src/PyClassifier.cc @@ -5,6 +5,7 @@ namespace pywrap { namespace np = boost::python::numpy; PyClassifier::PyClassifier(const std::string& module, const std::string& className) : module(module), className(className), fitted(false) { + id = reinterpret_cast(&this); pyWrap = PyWrap::GetInstance(); pyWrap->importClass(module, className); } diff --git a/src/PyClassifier.h b/src/PyClassifier.h index 012a085..358d4e7 100644 --- a/src/PyClassifier.h +++ b/src/PyClassifier.h @@ -29,6 +29,7 @@ namespace pywrap { PyWrap* pyWrap; std::string module; std::string className; + uint32_t id; bool fitted; }; } /* namespace pywrap */ diff --git a/src/PyWrap.cc b/src/PyWrap.cc index 7a4851c..24d7b6c 100644 --- a/src/PyWrap.cc +++ b/src/PyWrap.cc @@ -4,6 +4,7 @@ #include #include #include +#include #include namespace pywrap { @@ -112,97 +113,31 @@ namespace pywrap { { return callMethodString(moduleName, className, "version"); } - // void printPyObject(PyObject* obj) - // { - // PyObject* pStr = PyObject_Str(obj); - // const char* str = PyUnicode_AsUTF8(pStr); - // printf("%s\n", str); - // Py_XDECREF(pStr); - // } - - // void printDictionary(PyObject* pDict) - // { - // PyObject* pKeys = PyDict_Keys(pDict); - // Py_ssize_t size = PyList_Size(pKeys); - - // for (Py_ssize_t i = 0; i < size; ++i) { - // PyObject* pKey = PyList_GetItem(pKeys, i); - // PyObject* pValue = PyDict_GetItem(pDict, pKey); - - // printf("%s: ", PyUnicode_AsUTF8(pKey)); - // printPyObject(pValue); - // } - - // Py_XDECREF(pKeys); - // } - void cleanDictionary(PyObject* pDict) - { - PyObject* pKeys = PyDict_Keys(pDict); - Py_ssize_t size = PyList_Size(pKeys); - for (Py_ssize_t i = 0; i < size; ++i) { - PyObject* pKey = PyList_GetItem(pKeys, i); - PyObject* pValue = PyDict_GetItem(pDict, pKey); - Py_XDECREF(pKey); - Py_XDECREF(pValue); - } - Py_XDECREF(pKeys); - Py_XDECREF(pDict); - } void PyWrap::setHyperparameters(const std::string& moduleName, const std::string& className, const json& hyperparameters) { - PyObject* args = PyDict_New(); - - // Build dictionary of arguments with a little help of chatGPT + // Set hyperparameters as attributes of the class std::cout << "Building dictionary of arguments" << std::endl; - try { - PyObject* pValue; - for (const auto& [key, value] : hyperparameters.items()) { - std::string type_name; - if (value.type_name() == "string") { - type_name = "s"; - pValue = Py_BuildValue("s", value.get().c_str()); - std::cout << key << " s " << value.get() << std::endl; + PyObject* pValue; + PyObject* instance = getClass(moduleName, className); + for (const auto& [key, value] : hyperparameters.items()) { + std::stringstream oss; + oss << value.type_name(); + if (oss.str() == "string") { + pValue = Py_BuildValue("s", value.get().c_str()); + } else { + if (value.is_number_integer()) { + pValue = Py_BuildValue("i", value.get()); } else { - if (value.is_number_integer()) { - pValue = Py_BuildValue("i", value.get()); - std::cout << key << " i " << value.get() << std::endl; - } else { - pValue = Py_BuildValue("f", value.get()); - std::cout << key << " f " << value.get() << std::endl; - } + pValue = Py_BuildValue("f", value.get()); } - PyDict_SetItemString(args, key.c_str(), pValue); - Py_XDECREF(pValue); } - } - catch (const std::exception& e) { - - Py_DECREF(args); - errorAbort(e.what()); - } - std::cout << "PyDict_Size=" << PyDict_Size(args) << std::endl; - std::cout << "Calling method set_args with" << std::endl; - //printDictionary(args); - Py_INCREF(args); - PyObject* result; - // Call the method with the argument dictionary with a little help of chatGPT - auto instance = getClass(moduleName, className); - try { - if (!(result = PyObject_CallMethod(instance, "set_params", "O", args))) { - //if (!(result = PyObject_Call(instance, PyObject_GetAttrString(instance, "set_params"), args, nullptr))) - std::cout << "Cleaning up because of error" << std::endl; + int res = PyObject_SetAttrString(instance, key.c_str(), pValue); + if (res == -1 && PyErr_Occurred()) { cleanDictionary(args); - errorAbort("Couldn't call method set_args"); + errorAbort("Couldn't set attribute " + key + "=" + value.dump()); } + Py_XDECREF(pValue); } - catch (const std::exception& e) { - std::cout << "Cleaning up because of exception" << std::endl; - cleanDictionary(args); - errorAbort(e.what()); - } - std::cout << "Cleaning up everything went ok!" << std::endl; - cleanDictionary(args); - Py_XDECREF(result); } void PyWrap::fit(const std::string& moduleName, const std::string& className, CPyObject& X, CPyObject& y) { @@ -216,7 +151,7 @@ namespace pywrap { catch (const std::exception& e) { errorAbort(e.what()); } - Py_XDECREF(result); + // Py_XDECREF(result); } PyObject* PyWrap::predict(const std::string& moduleName, const std::string& className, CPyObject& X) @@ -247,7 +182,7 @@ namespace pywrap { errorAbort(e.what()); } double resultValue = PyFloat_AsDouble(result); - Py_XDECREF(result); + // Py_XDECREF(result); return resultValue; } } \ No newline at end of file diff --git a/src/main.cc b/src/main.cc index 8350575..bee7d06 100644 --- a/src/main.cc +++ b/src/main.cc @@ -53,16 +53,16 @@ int main(int argc, char* argv[]) cout << "X: " << X.sizes() << endl; cout << "y: " << y.sizes() << endl; auto clf = pywrap::STree(); - auto hyperparameters = json::parse("{\"C\": 0.7, \"max_iter\": 1e4, \"kernel\": \"rbf\"}"); - //clf.setHyperparameters(hyperparameters); + // auto stree = pywrap::STree(); + auto hyperparameters = json::parse("{\"C\": 0.7, \"max_iter\": 10000, \"kernel\": \"rbf\", \"random_state\": 17}"); + // stree.setHyperparameters(hyperparameters); cout << "STree Version: " << clf.version() << endl; auto svc = pywrap::SVC(); cout << "SVC with hyperparameters" << endl; - hyperparameters = json::parse("{\"kernel\": \"rbf\", \"C\": 0.7, \"random_state\": 17}"); - svc.setHyperparameters(hyperparameters); svc.fit(X, y, features, className, states); cout << "Graph: " << endl << clf.graph() << endl; clf.fit(X, y, features, className, states); + // stree.fit(X, y, features, className, states); auto prediction = clf.predict(X); cout << "Prediction: " << endl << "{"; for (int i = 0; i < prediction.size(0); ++i) { @@ -74,6 +74,7 @@ int main(int argc, char* argv[]) auto xg = pywrap::RandomForest(); xg.fit(X, y, features, className, states); cout << "STree Score ......: " << clf.score(X, y) << endl; + // cout << "STree hyper score : " << stree.score(X, y) << endl; cout << "RandomForest Score: " << rf.score(X, y) << endl; cout << "SVC Score ........: " << svc.score(X, y) << endl; cout << "XGBoost Score ....: " << xg.score(X, y) << endl;