From 296ed6b785b3e5d93e4961197e5e680e66e36945 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Wed, 1 Nov 2023 14:13:45 +0100 Subject: [PATCH] Adding tensor methods --- pspp.jnl | 2 + src/PyClassifier.cc | 13 ++- src/PyClassifier.h | 5 +- src/PyWrap.cc | 211 ++++++++++++++++++++++++-------------------- src/PyWrap.h | 12 ++- src/STree.cc | 5 -- src/STree.h | 1 - src/SVC.cc | 5 +- src/SVC.h | 2 +- 9 files changed, 144 insertions(+), 112 deletions(-) diff --git a/pspp.jnl b/pspp.jnl index cf047f4..3072ca7 100644 --- a/pspp.jnl +++ b/pspp.jnl @@ -1 +1,3 @@ GET FILE="/home/rmontanana/Code/covbench/data/covid_v9_20220630.sav". +SHOW SYSTEM. +SHOW SYSTEM. diff --git a/src/PyClassifier.cc b/src/PyClassifier.cc index 043ea5f..3012ae8 100644 --- a/src/PyClassifier.cc +++ b/src/PyClassifier.cc @@ -1,4 +1,5 @@ #include "PyClassifier.h" +#include namespace pywrap { @@ -13,13 +14,21 @@ namespace pywrap { pyWrap->clean(module, className); } - std::string PyClassifier::callMethod(const std::string& method) + std::string PyClassifier::version() + { + return pyWrap->version(module, className); + } + + std::string PyClassifier::callMethodString(const std::string& method) { return pyWrap->callMethodString(module, className, method); } PyClassifier& PyClassifier::fit(torch::Tensor& X, torch::Tensor& y, const std::vector& features, const std::string& className, std::map>& states) { - + PyObject* Xp = NULL;//THPVariable_Wrap(X); + PyObject* yp = NULL;//THPVariable_Wrap(y); + pyWrap->fit(module, className, Xp, yp); + return *this; } } /* namespace PyWrap */ \ No newline at end of file diff --git a/src/PyClassifier.h b/src/PyClassifier.h index 13337ed..e23a0b2 100644 --- a/src/PyClassifier.h +++ b/src/PyClassifier.h @@ -12,7 +12,10 @@ namespace pywrap { PyClassifier(const std::string& module, const std::string& className); virtual ~PyClassifier(); PyClassifier& fit(torch::Tensor& X, torch::Tensor& y, const std::vector& features, const std::string& className, std::map>& states); - std::string callMethod(const std::string& method); + torch::Tensor predict(torch::Tensor& X); + double score(torch::Tensor& X, torch::Tensor& y); + std::string version(); + std::string callMethodString(const std::string& method); private: PyWrap* pyWrap; std::string module; diff --git a/src/PyWrap.cc b/src/PyWrap.cc index 5cc51ed..86ec928 100644 --- a/src/PyWrap.cc +++ b/src/PyWrap.cc @@ -83,34 +83,19 @@ namespace pywrap { PyErr_Print(); exit(1); } - template - T PyWrap::callMethod(const std::string& moduleName, const std::string& className, const std::string& method) + PyObject* PyWrap::getClass(const std::string& moduleName, const std::string& className) { - std::cout << "Llamando método" << std::endl; auto item = moduleClassMap.find({ moduleName, className }); if (item == moduleClassMap.end()) { errorAbort("Module " + moduleName + " and class " + className + " not found"); } std::cout << "Clase encontrada" << std::endl; - PyObject* instance = std::get<2>(item->second); - PyObject* result; - if (!(result = PyObject_CallMethod(instance, method.c_str(), NULL))) - errorAbort("Couldn't call method " + method); - - T value = PyUnicode_AsUTF8(result); - std::cout << "Result: " << value << std::endl; - Py_DECREF(result); - return value; + return std::get<2>(item->second); } std::string PyWrap::callMethodString(const std::string& moduleName, const std::string& className, const std::string& method) { - std::cout << "Llamando método" << std::endl; - auto item = moduleClassMap.find({ moduleName, className }); - if (item == moduleClassMap.end()) { - errorAbort("Module " + moduleName + " and class " + className + " not found"); - } - std::cout << "Clase encontrada" << std::endl; - PyObject* instance = std::get<2>(item->second); + std::cout << "Llamando método " << method << std::endl; + PyObject* instance = getClass(moduleName, className); PyObject* result; if (!(result = PyObject_CallMethod(instance, method.c_str(), NULL))) errorAbort("Couldn't call method " + method); @@ -120,83 +105,119 @@ namespace pywrap { Py_DECREF(result); return value; } - // void PyWrap::doCommand2() - // { - // PyObject* list = Py_BuildValue("[s]", "Stree"); - // // PyObject* module = PyImport_ImportModuleEx("stree", NULL, NULL, list); - // PyObject* module = PyImport_ImportModule("stree"); - // if (PyErr_Occurred()) { - // PyErr_Print(); - // cout << "Fails to obtain the module.\n"; - // return; - // } - // cout << "Antes de empezar" << endl; - // if (module != nullptr) { - // cout << "Lo consiguió!!!" << endl; - // // dict is a borrowed reference. - // auto pdict = PyModule_GetDict(module); - // if (pdict == nullptr) { - // cout << "Fails to get the dictionary.\n"; - // return; - // } - // Py_DECREF(module); - // PyObject* pKeys = PyDict_Keys(pdict); - // PyObject* pValues = PyDict_Values(pdict); - // map my_map; - // cout << "size: " << PyDict_Size(pdict) << endl; - // char* cstr_key = new char[100]; - // char* cstr_value = new char[500]; - // for (Py_ssize_t i = 0; i < PyDict_Size(pdict); ++i) { - // PyArg_Parse(PyList_GetItem(pKeys, i), "s", &cstr_key); - // PyArg_Parse(PyList_GetItem(pValues, i), "s", &cstr_value); - // //cout << cstr<< " "<< cstr2 < my_map; +// cout << "size: " << PyDict_Size(pdict) << endl; +// char* cstr_key = new char[100]; +// char* cstr_value = new char[500]; +// for (Py_ssize_t i = 0; i < PyDict_Size(pdict); ++i) { +// PyArg_Parse(PyList_GetItem(pKeys, i), "s", &cstr_key); +// PyArg_Parse(PyList_GetItem(pValues, i), "s", &cstr_value); +// //cout << cstr<< " "<< cstr2 < - T callMethod(const std::string& moduleName, const std::string& className, const std::string& method); - std::string callMethodString(const std::string& moduleName, const std::string& className, const std::string& method); // template T returnMethod(PyObject* result); // template std::string returnMethod(PyObject* result); // template int returnMethod(PyObject* result); @@ -29,11 +26,18 @@ namespace pywrap { // // at::Tensor& THPVariable_Unpack(PyObject * obj); // return THPVariable_Unpack(result); // }; - void importClass(const std::string& moduleName, const std::string& className); + // PyObject* callMethodArgs(const std::string& moduleName, const std::string& className, const std::string& method, PyObject* args); + void fit(const std::string& moduleName, const std::string& className, PyObject* X, PyObject* y); + PyObject* predict(const std::string& moduleName, const std::string& className, PyObject* X); + std::string callMethodString(const std::string& moduleName, const std::string& className, const std::string& method); + std::string version(const std::string& moduleName, const std::string& className); + double score(const std::string& moduleName, const std::string& className, PyObject* X, PyObject* y); void clean(const std::string& moduleName, const std::string& className); + void importClass(const std::string& moduleName, const std::string& className); // void doCommand2(); private: PyWrap(); + PyObject* getClass(const std::string& moduleName, const std::string& className); void errorAbort(const std::string& message); PyStatus initPython(); static PyWrap* wrapper; diff --git a/src/STree.cc b/src/STree.cc index e138601..45d766d 100644 --- a/src/STree.cc +++ b/src/STree.cc @@ -1,11 +1,6 @@ #include "STree.h" -#include namespace pywrap { - void STree::version() - { - std::cout << "Version: " << callMethod("version") << std::endl; - } } /* namespace pywrap */ \ No newline at end of file diff --git a/src/STree.h b/src/STree.h index 3484555..5299d6a 100644 --- a/src/STree.h +++ b/src/STree.h @@ -7,7 +7,6 @@ namespace pywrap { public: STree() : PyClassifier("stree", "Stree") {}; ~STree() = default; - void version(); private: }; diff --git a/src/SVC.cc b/src/SVC.cc index 2d432c6..993ad5f 100644 --- a/src/SVC.cc +++ b/src/SVC.cc @@ -1,11 +1,10 @@ #include "SVC.h" -#include namespace pywrap { - void SVC::version() + std::string SVC::version() { - std::cout << "repr_html: " << callMethod("_repr_html_") << std::endl; + return callMethodString("_repr_html_"); } } /* namespace pywrap */ \ No newline at end of file diff --git a/src/SVC.h b/src/SVC.h index e58d615..c84000a 100644 --- a/src/SVC.h +++ b/src/SVC.h @@ -7,7 +7,7 @@ namespace pywrap { public: SVC() : PyClassifier("sklearn.svm", "SVC") {}; ~SVC() = default; - void version(); + std::string version(); private: };