Begin templating
This commit is contained in:
@@ -1,15 +1,10 @@
|
||||
cmake_minimum_required(VERSION 3.5)
|
||||
project(testcpy)
|
||||
|
||||
set( CMAKE_CXX_STANDARD 17)
|
||||
set( CMAKE_CXX_STANDARD 20)
|
||||
set( CMAKE_CXX_STANDARD_REQUIRED ON )
|
||||
|
||||
find_package(Python3 3.11...3.11.9 COMPONENTS Interpreter Development REQUIRED)
|
||||
|
||||
message("Python_FOUND:${Python3_FOUND}")
|
||||
message("Python_VERSION:${Python3_VERSION}")
|
||||
message("Python_Development_FOUND:${Python3_Development_FOUND}")
|
||||
message("Python_LIBRARIES:${Python3_LIBRARIES}")
|
||||
message("Python_INCLUDE_DIRS ${Python3_INCLUDE_DIRS}")
|
||||
find_package(Torch REQUIRED)
|
||||
|
||||
add_subdirectory(src)
|
@@ -3,4 +3,4 @@ include_directories(${Python3_INCLUDE_DIRS})
|
||||
|
||||
add_executable(main main.cc STree.cc SVC.cc PyClassifier.cc PyWrap.cc)
|
||||
|
||||
target_link_libraries(main ${Python3_LIBRARIES})
|
||||
target_link_libraries(main ${Python3_LIBRARIES} "${TORCH_LIBRARIES}")
|
||||
|
@@ -13,9 +13,10 @@ namespace pywrap {
|
||||
pyWrap->clean(module, className);
|
||||
}
|
||||
|
||||
void PyClassifier::callMethod(const std::string& method)
|
||||
template<typename T>
|
||||
T PyClassifier::callMethod(const std::string& method)
|
||||
{
|
||||
pyWrap->callMethod(module, className, method);
|
||||
return pyWrap->callMethod<T>(module, className, method);
|
||||
}
|
||||
|
||||
} /* namespace PyWrap */
|
@@ -1,6 +1,9 @@
|
||||
#ifndef PYCLASSIFER_H
|
||||
#define PYCLASSIFER_H
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <torch/torch.h>
|
||||
#include "PyWrap.h"
|
||||
|
||||
namespace pywrap {
|
||||
@@ -8,7 +11,9 @@ namespace pywrap {
|
||||
public:
|
||||
PyClassifier(const std::string& module, const std::string& className);
|
||||
virtual ~PyClassifier();
|
||||
void callMethod(const std::string& method);
|
||||
PyClassifier& fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states);
|
||||
template <typename T>
|
||||
T callMethod(const std::string& method);
|
||||
private:
|
||||
PyWrap* pyWrap;
|
||||
std::string module;
|
||||
|
@@ -83,7 +83,8 @@ namespace pywrap {
|
||||
PyErr_Print();
|
||||
exit(1);
|
||||
}
|
||||
void PyWrap::callMethod(const std::string& moduleName, const std::string& className, const std::string& method)
|
||||
template<typename T>
|
||||
T PyWrap::callMethod(const std::string& moduleName, const std::string& className, const std::string& method)
|
||||
{
|
||||
std::cout << "Llamando método" << std::endl;
|
||||
auto item = moduleClassMap.find({ moduleName, className });
|
||||
@@ -95,8 +96,11 @@ namespace pywrap {
|
||||
PyObject* result;
|
||||
if (!(result = PyObject_CallMethod(instance, method.c_str(), NULL)))
|
||||
errorAbort("Couldn't call method " + method);
|
||||
std::cout << "Result: " << PyUnicode_AsUTF8(result) << std::endl;
|
||||
|
||||
T value = PyUnicode_AsUTF8(result);
|
||||
std::cout << "Result: " << value << std::endl;
|
||||
Py_DECREF(result);
|
||||
return value;
|
||||
}
|
||||
// void PyWrap::doCommand2()
|
||||
// {
|
||||
|
13
src/PyWrap.h
13
src/PyWrap.h
@@ -16,7 +16,18 @@ namespace pywrap {
|
||||
static PyWrap* GetInstance();
|
||||
void operator=(const PyWrap&) = delete;
|
||||
~PyWrap();
|
||||
void callMethod(const std::string& moduleName, const std::string& className, const std::string& method);
|
||||
template<typename T>
|
||||
T callMethod(const std::string& moduleName, const std::string& className, const std::string& method);
|
||||
template<typename T> T returnMethod(PyObject* result);
|
||||
template<std::string> std::string returnMethod(PyObject* result);
|
||||
template<int> int returnMethod(PyObject* result);
|
||||
template<bool> bool returnMethod(PyObject* result);
|
||||
template<torch::Tensor> torch::Tensor returnMethod(PyObject* result)
|
||||
{
|
||||
// PyObject* THPVariable_Wrap(at::Tensor t);
|
||||
// at::Tensor& THPVariable_Unpack(PyObject * obj);
|
||||
return THPVariable_Unpack(result);
|
||||
};
|
||||
void importClass(const std::string& moduleName, const std::string& className);
|
||||
void clean(const std::string& moduleName, const std::string& className);
|
||||
// void doCommand2();
|
||||
|
@@ -1,10 +1,11 @@
|
||||
#include "STree.h"
|
||||
#include <iostream>
|
||||
|
||||
namespace pywrap {
|
||||
|
||||
void STree::version()
|
||||
{
|
||||
callMethod("version");
|
||||
std::cout << "Version: " << callMethod<std::string>("version") << std::endl;
|
||||
}
|
||||
|
||||
} /* namespace pywrap */
|
@@ -1,10 +1,11 @@
|
||||
#include "SVC.h"
|
||||
#include <iostream>
|
||||
|
||||
namespace pywrap {
|
||||
|
||||
void SVC::version()
|
||||
{
|
||||
callMethod("_repr_html_");
|
||||
//std::cout << "repr_html: " << callMethod<std::string>("_repr_html_") << std::endl;
|
||||
}
|
||||
|
||||
} /* namespace pywrap */
|
30
src/example.cpp
Normal file
30
src/example.cpp
Normal file
@@ -0,0 +1,30 @@
|
||||
#include<string>
|
||||
#include<iostream>
|
||||
|
||||
using namespace std;
|
||||
class Test {
|
||||
public:
|
||||
Test(const string& c) : c(c) {};
|
||||
~Test() { std::cout << "Destructor" << std::endl; };
|
||||
|
||||
template<typename T>
|
||||
T callMethod(const T& parameter)
|
||||
{
|
||||
std::cout << "Llamando a metodo" << std::endl;
|
||||
return parameter;
|
||||
}
|
||||
|
||||
|
||||
private:
|
||||
string c;
|
||||
|
||||
};
|
||||
|
||||
int main()
|
||||
{
|
||||
Test t("hola");
|
||||
cout << t.callMethod<string>("hola") << endl;
|
||||
cout << t.callMethod<int>(1) << endl;
|
||||
cout << t.callMethod<double>(7.3) << endl;
|
||||
return 0;
|
||||
}
|
Reference in New Issue
Block a user