From 77c33942f69eb52bf6ae3a2659bea4be3a7b1720 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Mon, 30 Oct 2023 22:45:35 +0100 Subject: [PATCH] Begin templating --- CMakeLists.txt | 9 ++------- src/CMakeLists.txt | 2 +- src/PyClassifier.cc | 5 +++-- src/PyClassifier.h | 7 ++++++- src/PyWrap.cc | 8 ++++++-- src/PyWrap.h | 13 ++++++++++++- src/STree.cc | 3 ++- src/SVC.cc | 3 ++- src/example.cpp | 30 ++++++++++++++++++++++++++++++ 9 files changed, 64 insertions(+), 16 deletions(-) create mode 100644 src/example.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index ed98400..64884e3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,15 +1,10 @@ cmake_minimum_required(VERSION 3.5) project(testcpy) -set( CMAKE_CXX_STANDARD 17) +set( CMAKE_CXX_STANDARD 20) set( CMAKE_CXX_STANDARD_REQUIRED ON ) find_package(Python3 3.11...3.11.9 COMPONENTS Interpreter Development REQUIRED) - -message("Python_FOUND:${Python3_FOUND}") -message("Python_VERSION:${Python3_VERSION}") -message("Python_Development_FOUND:${Python3_Development_FOUND}") -message("Python_LIBRARIES:${Python3_LIBRARIES}") -message("Python_INCLUDE_DIRS ${Python3_INCLUDE_DIRS}") +find_package(Torch REQUIRED) add_subdirectory(src) \ No newline at end of file diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 33ace1e..c93ed54 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -3,4 +3,4 @@ include_directories(${Python3_INCLUDE_DIRS}) add_executable(main main.cc STree.cc SVC.cc PyClassifier.cc PyWrap.cc) -target_link_libraries(main ${Python3_LIBRARIES}) +target_link_libraries(main ${Python3_LIBRARIES} "${TORCH_LIBRARIES}") diff --git a/src/PyClassifier.cc b/src/PyClassifier.cc index b72f19b..36cc41a 100644 --- a/src/PyClassifier.cc +++ b/src/PyClassifier.cc @@ -13,9 +13,10 @@ namespace pywrap { pyWrap->clean(module, className); } - void PyClassifier::callMethod(const std::string& method) + template + T PyClassifier::callMethod(const std::string& method) { - pyWrap->callMethod(module, className, method); + return pyWrap->callMethod(module, className, method); } } /* namespace PyWrap */ \ No newline at end of file diff --git a/src/PyClassifier.h b/src/PyClassifier.h index 3610033..a057664 100644 --- a/src/PyClassifier.h +++ b/src/PyClassifier.h @@ -1,6 +1,9 @@ #ifndef PYCLASSIFER_H #define PYCLASSIFER_H #include +#include +#include +#include #include "PyWrap.h" namespace pywrap { @@ -8,7 +11,9 @@ namespace pywrap { public: PyClassifier(const std::string& module, const std::string& className); virtual ~PyClassifier(); - void callMethod(const std::string& method); + PyClassifier& fit(torch::Tensor& X, torch::Tensor& y, const std::vector& features, const std::string& className, std::map>& states); + template + T callMethod(const std::string& method); private: PyWrap* pyWrap; std::string module; diff --git a/src/PyWrap.cc b/src/PyWrap.cc index 9dca6c7..d711705 100644 --- a/src/PyWrap.cc +++ b/src/PyWrap.cc @@ -83,7 +83,8 @@ namespace pywrap { PyErr_Print(); exit(1); } - void PyWrap::callMethod(const std::string& moduleName, const std::string& className, const std::string& method) + template + T PyWrap::callMethod(const std::string& moduleName, const std::string& className, const std::string& method) { std::cout << "Llamando método" << std::endl; auto item = moduleClassMap.find({ moduleName, className }); @@ -95,8 +96,11 @@ namespace pywrap { PyObject* result; if (!(result = PyObject_CallMethod(instance, method.c_str(), NULL))) errorAbort("Couldn't call method " + method); - std::cout << "Result: " << PyUnicode_AsUTF8(result) << std::endl; + + T value = PyUnicode_AsUTF8(result); + std::cout << "Result: " << value << std::endl; Py_DECREF(result); + return value; } // void PyWrap::doCommand2() // { diff --git a/src/PyWrap.h b/src/PyWrap.h index cad4c83..c5d4537 100644 --- a/src/PyWrap.h +++ b/src/PyWrap.h @@ -16,7 +16,18 @@ namespace pywrap { static PyWrap* GetInstance(); void operator=(const PyWrap&) = delete; ~PyWrap(); - void callMethod(const std::string& moduleName, const std::string& className, const std::string& method); + template + T callMethod(const std::string& moduleName, const std::string& className, const std::string& method); + template T returnMethod(PyObject* result); + template std::string returnMethod(PyObject* result); + template int returnMethod(PyObject* result); + template bool returnMethod(PyObject* result); + template torch::Tensor returnMethod(PyObject* result) + { + // PyObject* THPVariable_Wrap(at::Tensor t); + // at::Tensor& THPVariable_Unpack(PyObject * obj); + return THPVariable_Unpack(result); + }; void importClass(const std::string& moduleName, const std::string& className); void clean(const std::string& moduleName, const std::string& className); // void doCommand2(); diff --git a/src/STree.cc b/src/STree.cc index 3db97a7..669b15f 100644 --- a/src/STree.cc +++ b/src/STree.cc @@ -1,10 +1,11 @@ #include "STree.h" +#include namespace pywrap { void STree::version() { - callMethod("version"); + std::cout << "Version: " << callMethod("version") << std::endl; } } /* namespace pywrap */ \ No newline at end of file diff --git a/src/SVC.cc b/src/SVC.cc index 2d9dcb9..d247c06 100644 --- a/src/SVC.cc +++ b/src/SVC.cc @@ -1,10 +1,11 @@ #include "SVC.h" +#include namespace pywrap { void SVC::version() { - callMethod("_repr_html_"); + //std::cout << "repr_html: " << callMethod("_repr_html_") << std::endl; } } /* namespace pywrap */ \ No newline at end of file diff --git a/src/example.cpp b/src/example.cpp new file mode 100644 index 0000000..e738315 --- /dev/null +++ b/src/example.cpp @@ -0,0 +1,30 @@ +#include +#include + +using namespace std; +class Test { +public: + Test(const string& c) : c(c) {}; + ~Test() { std::cout << "Destructor" << std::endl; }; + + template + T callMethod(const T& parameter) + { + std::cout << "Llamando a metodo" << std::endl; + return parameter; + } + + +private: + string c; + +}; + +int main() +{ + Test t("hola"); + cout << t.callMethod("hola") << endl; + cout << t.callMethod(1) << endl; + cout << t.callMethod(7.3) << endl; + return 0; +} \ No newline at end of file