From 67e4b0af479740cc738a2a94ba54a8182deec417 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Thu, 9 Nov 2023 12:56:01 +0100 Subject: [PATCH] PyWrap with built dictionary of arguments --- src/PyClassifier.cc | 4 +- src/PyWrap.cc | 114 +++++++++++++++++++++++++++++++++++++++----- src/PyWrap.h | 3 ++ src/SVC.cc | 7 +++ src/SVC.h | 1 + src/main.cc | 8 +++- 6 files changed, 120 insertions(+), 17 deletions(-) diff --git a/src/PyClassifier.cc b/src/PyClassifier.cc index a323927..69ee5aa 100644 --- a/src/PyClassifier.cc +++ b/src/PyClassifier.cc @@ -1,6 +1,5 @@ #include "PyClassifier.h" #include - namespace pywrap { namespace bp = boost::python; namespace np = boost::python::numpy; @@ -38,7 +37,8 @@ namespace pywrap { PyClassifier& PyClassifier::fit(torch::Tensor& X, torch::Tensor& y, const std::vector& features, const std::string& className, std::map>& states) { if (!fitted && hyperparameters.size() > 0) { - std::cout << "Setting hyperparameters" << std::endl; + std::cout << "PyClassifier: Setting hyperparameters" << std::endl; + pyWrap->setHyperparameters(module, this->className, hyperparameters); } auto [Xn, yn] = tensors2numpy(X, y); CPyObject Xp = bp::incref(bp::object(Xn).ptr()); diff --git a/src/PyWrap.cc b/src/PyWrap.cc index 0c30326..7a4851c 100644 --- a/src/PyWrap.cc +++ b/src/PyWrap.cc @@ -3,6 +3,7 @@ #include "PyWrap.h" #include #include +#include #include namespace pywrap { @@ -101,9 +102,7 @@ namespace pywrap { errorAbort("Couldn't call method " + method); } catch (const std::exception& e) { - std::cerr << e.what() << '\n'; - RemoveInstance(); - exit(1); + errorAbort(e.what()); } std::string value = PyUnicode_AsUTF8(result); Py_XDECREF(result); @@ -113,6 +112,98 @@ namespace pywrap { { return callMethodString(moduleName, className, "version"); } + // void printPyObject(PyObject* obj) + // { + // PyObject* pStr = PyObject_Str(obj); + // const char* str = PyUnicode_AsUTF8(pStr); + // printf("%s\n", str); + // 
Py_XDECREF(pStr);
+    // }
+
+    // void printDictionary(PyObject* pDict)
+    // {
+    //     PyObject* pKeys = PyDict_Keys(pDict);
+    //     Py_ssize_t size = PyList_Size(pKeys);
+
+    //     for (Py_ssize_t i = 0; i < size; ++i) {
+    //         PyObject* pKey = PyList_GetItem(pKeys, i);
+    //         PyObject* pValue = PyDict_GetItem(pDict, pKey);
+
+    //         printf("%s: ", PyUnicode_AsUTF8(pKey));
+    //         printPyObject(pValue);
+    //     }
+
+    //     Py_XDECREF(pKeys);
+    // }
+    // Release the caller's reference to a dictionary created with PyDict_New().
+    // PyList_GetItem() and PyDict_GetItem() return borrowed references, so the keys
+    // and values must not be decref'd here: the dictionary owns them and releases
+    // them when its own reference count drops to zero. A NULL pointer is accepted.
+    void cleanDictionary(PyObject* pDict)
+    {
+        Py_XDECREF(pDict);
+    }
+    // Pass the JSON hyperparameters to the Python estimator as keyword arguments,
+    // i.e. instance.set_params(**hyperparameters). Strings, booleans, integers and
+    // floats are converted; any other JSON type makes get<double>() throw and the
+    // error is reported through errorAbort().
+    void PyWrap::setHyperparameters(const std::string& moduleName, const std::string& className, const json& hyperparameters)
+    {
+        // errorAbort() presumably never returns; the returns below are a safety net.
+        PyObject* args = PyDict_New();
+        if (!args) {
+            errorAbort("Couldn't create the dictionary of arguments");
+            return;
+        }
+        // Build dictionary of arguments
+        std::string error;
+        try {
+            for (const auto& [key, value] : hyperparameters.items()) {
+                PyObject* pValue = nullptr;
+                // type_name() returns a const char*; comparing it to "string" with ==
+                // compares addresses, not text, so the is_*() predicates are used.
+                if (value.is_string()) {
+                    pValue = PyUnicode_FromString(value.get<std::string>().c_str());
+                } else if (value.is_boolean()) {
+                    pValue = PyBool_FromLong(value.get<bool>());
+                } else if (value.is_number_integer()) {
+                    pValue = PyLong_FromLongLong(value.get<long long>());
+                } else {
+                    pValue = PyFloat_FromDouble(value.get<double>());
+                }
+                // PyDict_SetItemString() does not steal the reference to pValue.
+                const bool stored = pValue && PyDict_SetItemString(args, key.c_str(), pValue) == 0;
+                Py_XDECREF(pValue);
+                if (!stored) {
+                    error = "Couldn't store hyperparameter " + key;
+                    break;
+                }
+            }
+        }
+        catch (const std::exception& e) {
+            error = e.what();
+        }
+        if (!error.empty()) {
+            cleanDictionary(args);
+            errorAbort(error);
+            return;
+        }
+        // set_params(**params) only takes keyword arguments, so the dictionary is
+        // passed as kwargs with an empty positional tuple, not as a single "O" argument.
+        PyObject* instance = getClass(moduleName, className);
+        PyObject* method = instance ? PyObject_GetAttrString(instance, "set_params") : nullptr;
+        PyObject* noArgs = PyTuple_New(0);
+        PyObject* result = (method && noArgs) ? PyObject_Call(method, noArgs, args) : nullptr;
+        Py_XDECREF(noArgs);
+        Py_XDECREF(method);
+        cleanDictionary(args);
+        if (!result) {
+            errorAbort("Couldn't call method set_params");
+            return;
+        }
+        // PyObject_Call() returns a new reference on success; it is not needed here.
+        Py_DECREF(result);
+    }
     void PyWrap::fit(const std::string& moduleName, const std::string& className, CPyObject& X, CPyObject& y)
     {
         PyObject* instance = getClass(moduleName, className);
@@ -123,10 +214,9 @@ namespace pywrap {
                 errorAbort("Couldn't call method fit");
         }
         catch (const std::exception& e) {
-            std::cerr << e.what() << '\n';
-            RemoveInstance();
-            exit(1);
+            errorAbort(e.what());
         }
+        Py_XDECREF(result);
     }
     PyObject* PyWrap::predict(const std::string& moduleName, const std::string& className, CPyObject& X)
@@ -139,9 +229,7 @@ namespace pywrap {
                 errorAbort("Couldn't call method predict");
         }
         catch (const std::exception& e) {
-            std::cerr << e.what() << '\n';
-            RemoveInstance();
-            exit(1);
+            errorAbort(e.what());
         }
         Py_INCREF(result);
         return result; // Caller must free this object
@@ -156,10 +244,10 @@ namespace pywrap {
                 errorAbort("Couldn't call method score");
         }
         catch (const std::exception& e) {
-            std::cerr << e.what() << '\n';
-            RemoveInstance();
-            exit(1);
+            errorAbort(e.what());
         }
-        return PyFloat_AsDouble(result);
+        double resultValue = PyFloat_AsDouble(result);
+        Py_XDECREF(result);
+        return resultValue;
     }
 }
\ No newline at end of file
diff --git a/src/PyWrap.h b/src/PyWrap.h
index 4b271e8..f122837 100644
--- a/src/PyWrap.h
+++ b/src/PyWrap.h
@@ -5,6 +5,7 @@
 #include
 #include
 #include
+#include #include "PyHelper.hpp" #pragma once @@ -13,6 +14,7 @@ namespace pywrap { /* Singleton class to handle Python/numpy interpreter. */ + using json = nlohmann::json; class PyWrap { public: PyWrap() = default; @@ -22,6 +24,7 @@ namespace pywrap { ~PyWrap() = default; std::string callMethodString(const std::string& moduleName, const std::string& className, const std::string& method); std::string version(const std::string& moduleName, const std::string& className); + void setHyperparameters(const std::string& moduleName, const std::string& className, const json& hyperparameters); void fit(const std::string& moduleName, const std::string& className, CPyObject& X, CPyObject& y); PyObject* predict(const std::string& moduleName, const std::string& className, CPyObject& X); double score(const std::string& moduleName, const std::string& className, CPyObject& X, CPyObject& y); diff --git a/src/SVC.cc b/src/SVC.cc index 9268c8b..225bf0a 100644 --- a/src/SVC.cc +++ b/src/SVC.cc @@ -5,4 +5,11 @@ namespace pywrap { { return callMethodString("1.0"); } + void SVC::setHyperparameters(const nlohmann::json& hyperparameters) + { + // Validate the keys via checkHyperparameters() against the names accepted for SVC (its reaction to an unknown key is defined elsewhere - TODO confirm), then store them + const std::vector validKeys = { "C", "gamma", "kernel", "random_state" }; + checkHyperparameters(validKeys, hyperparameters); + this->hyperparameters = hyperparameters; + } } /* namespace pywrap */ \ No newline at end of file diff --git a/src/SVC.h b/src/SVC.h index 41cfb52..1c755d9 100644 --- a/src/SVC.h +++ b/src/SVC.h @@ -8,6 +8,7 @@ namespace pywrap { SVC() : PyClassifier("sklearn.svm", "SVC") {}; ~SVC() = default; std::string version(); + void setHyperparameters(const nlohmann::json& hyperparameters) override; }; } /* namespace pywrap */ diff --git a/src/main.cc b/src/main.cc index 3f96720..8350575 100644 --- a/src/main.cc +++ b/src/main.cc @@ -43,6 +43,7 @@ tuple, string, map>> loadData int main(int argc, char* argv[]) { + using json = nlohmann::json; cout << "* Begin."
<< endl; { auto datasetName = "iris"; @@ -52,10 +53,13 @@ int main(int argc, char* argv[]) cout << "X: " << X.sizes() << endl; cout << "y: " << y.sizes() << endl; auto clf = pywrap::STree(); - auto hyperparameters = nlohmann::json({ "max_depth": 3, "C" : 0.7 }); - clf.setHyperparameters(hyperparameters); + auto hyperparameters = json::parse("{\"C\": 0.7, \"max_iter\": 1e4, \"kernel\": \"rbf\"}"); + //clf.setHyperparameters(hyperparameters); cout << "STree Version: " << clf.version() << endl; auto svc = pywrap::SVC(); + cout << "SVC with hyperparameters" << endl; + hyperparameters = json::parse("{\"kernel\": \"rbf\", \"C\": 0.7, \"random_state\": 17}"); + svc.setHyperparameters(hyperparameters); svc.fit(X, y, features, className, states); cout << "Graph: " << endl << clf.graph() << endl; clf.fit(X, y, features, className, states);