Begin templating

This commit is contained in:
2023-10-30 22:45:35 +01:00
parent e26acc3676
commit 77c33942f6
9 changed files with 64 additions and 16 deletions

View File

@@ -1,15 +1,10 @@
cmake_minimum_required(VERSION 3.5)
project(testcpy)
set( CMAKE_CXX_STANDARD 17)
set( CMAKE_CXX_STANDARD 20)
set( CMAKE_CXX_STANDARD_REQUIRED ON )
find_package(Python3 3.11...3.11.9 COMPONENTS Interpreter Development REQUIRED)
message("Python_FOUND:${Python3_FOUND}")
message("Python_VERSION:${Python3_VERSION}")
message("Python_Development_FOUND:${Python3_Development_FOUND}")
message("Python_LIBRARIES:${Python3_LIBRARIES}")
message("Python_INCLUDE_DIRS ${Python3_INCLUDE_DIRS}")
find_package(Torch REQUIRED)
add_subdirectory(src)

View File

@@ -3,4 +3,4 @@ include_directories(${Python3_INCLUDE_DIRS})
add_executable(main main.cc STree.cc SVC.cc PyClassifier.cc PyWrap.cc)
target_link_libraries(main ${Python3_LIBRARIES})
target_link_libraries(main ${Python3_LIBRARIES} "${TORCH_LIBRARIES}")

View File

@@ -13,9 +13,10 @@ namespace pywrap {
pyWrap->clean(module, className);
}
void PyClassifier::callMethod(const std::string& method)
template<typename T>
T PyClassifier::callMethod(const std::string& method)
{
pyWrap->callMethod(module, className, method);
return pyWrap->callMethod<T>(module, className, method);
}
} /* namespace PyWrap */

View File

@@ -1,6 +1,9 @@
#ifndef PYCLASSIFER_H
#define PYCLASSIFER_H
#include <string>
#include <map>
#include <vector>
#include <torch/torch.h>
#include "PyWrap.h"
namespace pywrap {
@@ -8,7 +11,9 @@ namespace pywrap {
public:
PyClassifier(const std::string& module, const std::string& className);
virtual ~PyClassifier();
void callMethod(const std::string& method);
PyClassifier& fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states);
template <typename T>
T callMethod(const std::string& method);
private:
PyWrap* pyWrap;
std::string module;

View File

@@ -83,7 +83,8 @@ namespace pywrap {
PyErr_Print();
exit(1);
}
void PyWrap::callMethod(const std::string& moduleName, const std::string& className, const std::string& method)
template<typename T>
T PyWrap::callMethod(const std::string& moduleName, const std::string& className, const std::string& method)
{
std::cout << "Llamando método" << std::endl;
auto item = moduleClassMap.find({ moduleName, className });
@@ -95,8 +96,11 @@ namespace pywrap {
PyObject* result;
if (!(result = PyObject_CallMethod(instance, method.c_str(), NULL)))
errorAbort("Couldn't call method " + method);
std::cout << "Result: " << PyUnicode_AsUTF8(result) << std::endl;
T value = PyUnicode_AsUTF8(result);
std::cout << "Result: " << value << std::endl;
Py_DECREF(result);
return value;
}
// void PyWrap::doCommand2()
// {

View File

@@ -16,7 +16,18 @@ namespace pywrap {
static PyWrap* GetInstance();
void operator=(const PyWrap&) = delete;
~PyWrap();
void callMethod(const std::string& moduleName, const std::string& className, const std::string& method);
template<typename T>
T callMethod(const std::string& moduleName, const std::string& className, const std::string& method);
template<typename T> T returnMethod(PyObject* result);
template<std::string> std::string returnMethod(PyObject* result);
template<int> int returnMethod(PyObject* result);
template<bool> bool returnMethod(PyObject* result);
template<torch::Tensor> torch::Tensor returnMethod(PyObject* result)
{
// PyObject* THPVariable_Wrap(at::Tensor t);
// at::Tensor& THPVariable_Unpack(PyObject * obj);
return THPVariable_Unpack(result);
};
void importClass(const std::string& moduleName, const std::string& className);
void clean(const std::string& moduleName, const std::string& className);
// void doCommand2();

View File

@@ -1,10 +1,11 @@
#include "STree.h"
#include <iostream>
namespace pywrap {
void STree::version()
{
callMethod("version");
std::cout << "Version: " << callMethod<std::string>("version") << std::endl;
}
} /* namespace pywrap */

View File

@@ -1,10 +1,11 @@
#include "SVC.h"
#include <iostream>
namespace pywrap {
void SVC::version()
{
callMethod("_repr_html_");
//std::cout << "repr_html: " << callMethod<std::string>("_repr_html_") << std::endl;
}
} /* namespace pywrap */

30
src/example.cpp Normal file
View File

@@ -0,0 +1,30 @@
#include<string>
#include<iostream>
using namespace std;
class Test {
public:
Test(const string& c) : c(c) {};
~Test() { std::cout << "Destructor" << std::endl; };
template<typename T>
T callMethod(const T& parameter)
{
std::cout << "Llamando a metodo" << std::endl;
return parameter;
}
private:
string c;
};
int main()
{
Test t("hola");
cout << t.callMethod<string>("hola") << endl;
cout << t.callMethod<int>(1) << endl;
cout << t.callMethod<double>(7.3) << endl;
return 0;
}