Add torch library

This commit is contained in:
2023-10-31 10:07:24 +01:00
parent 77c33942f6
commit cb3281ed91
16 changed files with 884 additions and 27 deletions

View File

@@ -1,6 +1,8 @@
include_directories(${PyWrap_SOURCE_DIR}/lib/Files)
include_directories(${Python3_INCLUDE_DIRS})
add_executable(main main.cc STree.cc SVC.cc PyClassifier.cc PyWrap.cc)
add_executable(example example.cpp)
target_link_libraries(main ${Python3_LIBRARIES} "${TORCH_LIBRARIES}")
target_link_libraries(main ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ArffFiles)
target_link_libraries(example "${TORCH_LIBRARIES}" ArffFiles)

View File

@@ -13,10 +13,13 @@ namespace pywrap {
pyWrap->clean(module, className);
}
template<typename T>
T PyClassifier::callMethod(const std::string& method)
std::string PyClassifier::callMethod(const std::string& method)
{
return pyWrap->callMethod<T>(module, className, method);
return pyWrap->callMethodString(module, className, method);
}
PyClassifier& PyClassifier::fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states)
{
}
} /* namespace PyWrap */

View File

@@ -12,8 +12,7 @@ namespace pywrap {
PyClassifier(const std::string& module, const std::string& className);
virtual ~PyClassifier();
PyClassifier& fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states);
template <typename T>
T callMethod(const std::string& method);
std::string callMethod(const std::string& method);
private:
PyWrap* pyWrap;
std::string module;

View File

@@ -102,6 +102,24 @@ namespace pywrap {
Py_DECREF(result);
return value;
}
std::string PyWrap::callMethodString(const std::string& moduleName, const std::string& className, const std::string& method)
{
std::cout << "Llamando método" << std::endl;
auto item = moduleClassMap.find({ moduleName, className });
if (item == moduleClassMap.end()) {
errorAbort("Module " + moduleName + " and class " + className + " not found");
}
std::cout << "Clase encontrada" << std::endl;
PyObject* instance = std::get<2>(item->second);
PyObject* result;
if (!(result = PyObject_CallMethod(instance, method.c_str(), NULL)))
errorAbort("Couldn't call method " + method);
std::string value = PyUnicode_AsUTF8(result);
std::cout << "Result: " << value << std::endl;
Py_DECREF(result);
return value;
}
// void PyWrap::doCommand2()
// {
// PyObject* list = Py_BuildValue("[s]", "Stree");

View File

@@ -18,16 +18,17 @@ namespace pywrap {
~PyWrap();
template<typename T>
T callMethod(const std::string& moduleName, const std::string& className, const std::string& method);
template<typename T> T returnMethod(PyObject* result);
template<std::string> std::string returnMethod(PyObject* result);
template<int> int returnMethod(PyObject* result);
template<bool> bool returnMethod(PyObject* result);
template<torch::Tensor> torch::Tensor returnMethod(PyObject* result)
{
// PyObject* THPVariable_Wrap(at::Tensor t);
// at::Tensor& THPVariable_Unpack(PyObject * obj);
return THPVariable_Unpack(result);
};
std::string callMethodString(const std::string& moduleName, const std::string& className, const std::string& method);
// template<typename T> T returnMethod(PyObject* result);
// template<std::string> std::string returnMethod(PyObject* result);
// template<int> int returnMethod(PyObject* result);
// template<bool> bool returnMethod(PyObject* result);
// template<torch::Tensor> torch::Tensor returnMethod(PyObject* result)
// {
// // PyObject* THPVariable_Wrap(at::Tensor t);
// // at::Tensor& THPVariable_Unpack(PyObject * obj);
// return THPVariable_Unpack(result);
// };
void importClass(const std::string& moduleName, const std::string& className);
void clean(const std::string& moduleName, const std::string& className);
// void doCommand2();

View File

@@ -5,7 +5,7 @@ namespace pywrap {
void STree::version()
{
std::cout << "Version: " << callMethod<std::string>("version") << std::endl;
std::cout << "Version: " << callMethod("version") << std::endl;
}
} /* namespace pywrap */

View File

@@ -5,7 +5,7 @@ namespace pywrap {
void SVC::version()
{
//std::cout << "repr_html: " << callMethod<std::string>("_repr_html_") << std::endl;
std::cout << "repr_html: " << callMethod("_repr_html_") << std::endl;
}
} /* namespace pywrap */

View File

@@ -1,7 +1,10 @@
#include <torch/torch.h>
#include "ArffFiles.h"
#include<string>
#include<iostream>
using namespace std;
using namespace torch;
class Test {
public:
Test(const string& c) : c(c) {};
@@ -13,18 +16,45 @@ public:
std::cout << "Llamando a metodo" << std::endl;
return parameter;
}
private:
string c;
};
tuple<Tensor, Tensor, vector<string>, string, map<string, vector<int>>> loadDataset(const string& name, bool class_last)
{
auto handler = ArffFiles();
handler.load(static_cast<string>(name) + ".arff", class_last);
// Get Dataset X, y
vector<vector<float>> X = handler.getX();
vector<int> y = handler.getY();
// // Get className & Features
auto className = handler.getClassName();
vector<string> features;
auto attributes = handler.getAttributes();
transform(attributes.begin(), attributes.end(), back_inserter(features), [](const auto& pair) { return pair.first; });
torch::Tensor Xd;
auto states = map<string, vector<int>>();
auto yt = torch::tensor(y, torch::kInt32);
Xd = torch::zeros({ static_cast<int>(X.size()), static_cast<int>(X[0].size()) }, torch::kFloat32);
for (int i = 0; i < features.size(); ++i) {
Xd.index_put_({ i, "..." }, torch::tensor(X[i], torch::kFloat32));
}
return make_tuple(Xd, yt, features, className, states);
}
int main()
{
Test t("hola");
cout << t.callMethod<string>("hola") << endl;
cout << t.callMethod<int>(1) << endl;
cout << t.callMethod<double>(7.3) << endl;
vector<vector<float>> X;
vector<int> y = { 1, 2, 3 };
X.push_back({ 1.1, 2.2, 3.3 });
vector<float> v = { 1.1, 2.2, 3.3 };
torch::Tensor matrix = torch::tensor(X[0], torch::kFloat32);
cout << "X:" << matrix << endl;
cout << "y:" << torch::tensor(y, torch::kInt32) << endl;
return 0;
}

View File

@@ -1,10 +1,48 @@
#include <torch/torch.h>
#include "ArffFiles.h"
#include <vector>
#include <string>
#include <iostream>
#include <map>
#include <tuple>
#include "STree.h"
#include "SVC.h"
using namespace std;
using namespace torch;
class Paths {
public:
static string datasets()
{
return "/home/rmontanana/Code/discretizbench/datasets/";
}
};
tuple<Tensor, Tensor, vector<string>, string, map<string, vector<int>>> loadDataset(const string& name, bool class_last)
{
auto handler = ArffFiles();
handler.load(Paths::datasets() + static_cast<string>(name) + ".arff", class_last);
// Get Dataset X, y
vector<vector<float>> X = handler.getX();
vector<int> y = handler.getY();
// Get className & Features
auto className = handler.getClassName();
vector<string> features;
auto attributes = handler.getAttributes();
transform(attributes.begin(), attributes.end(), back_inserter(features), [](const auto& pair) { return pair.first; });
Tensor Xd;
auto states = map<string, vector<int>>();
Xd = torch::zeros({ static_cast<int>(X.size()), static_cast<int>(X[0].size()) }, torch::kFloat32);
for (int i = 0; i < features.size(); ++i) {
Xd.index_put_({ i, "..." }, torch::tensor(X[i], torch::kFloat32));
}
return { Xd, torch::tensor(y, torch::kInt32), features, className, states };
}
int main(int argc, char* argv[])
{
// auto wrap = pywrap::PyWrap("stree", "Stree");
// wrap.callMethod("version");
auto [X, y, features, className, states] = loadDataset("iris", true);
auto stree = pywrap::STree();
stree.version();
auto svc = pywrap::SVC();