PyWrap with built dictionary of arguments
This commit is contained in:
114
src/PyWrap.cc
114
src/PyWrap.cc
@@ -3,6 +3,7 @@
|
||||
#include "PyWrap.h"
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <iostream>
|
||||
#include <boost/python/numpy.hpp>
|
||||
|
||||
namespace pywrap {
|
||||
@@ -101,9 +102,7 @@ namespace pywrap {
|
||||
errorAbort("Couldn't call method " + method);
|
||||
}
|
||||
catch (const std::exception& e) {
|
||||
std::cerr << e.what() << '\n';
|
||||
RemoveInstance();
|
||||
exit(1);
|
||||
errorAbort(e.what());
|
||||
}
|
||||
std::string value = PyUnicode_AsUTF8(result);
|
||||
Py_XDECREF(result);
|
||||
@@ -113,6 +112,98 @@ namespace pywrap {
|
||||
{
|
||||
return callMethodString(moduleName, className, "version");
|
||||
}
|
||||
// void printPyObject(PyObject* obj)
|
||||
// {
|
||||
// PyObject* pStr = PyObject_Str(obj);
|
||||
// const char* str = PyUnicode_AsUTF8(pStr);
|
||||
// printf("%s\n", str);
|
||||
// Py_XDECREF(pStr);
|
||||
// }
|
||||
|
||||
// void printDictionary(PyObject* pDict)
|
||||
// {
|
||||
// PyObject* pKeys = PyDict_Keys(pDict);
|
||||
// Py_ssize_t size = PyList_Size(pKeys);
|
||||
|
||||
// for (Py_ssize_t i = 0; i < size; ++i) {
|
||||
// PyObject* pKey = PyList_GetItem(pKeys, i);
|
||||
// PyObject* pValue = PyDict_GetItem(pDict, pKey);
|
||||
|
||||
// printf("%s: ", PyUnicode_AsUTF8(pKey));
|
||||
// printPyObject(pValue);
|
||||
// }
|
||||
|
||||
// Py_XDECREF(pKeys);
|
||||
// }
|
||||
void cleanDictionary(PyObject* pDict)
|
||||
{
|
||||
PyObject* pKeys = PyDict_Keys(pDict);
|
||||
Py_ssize_t size = PyList_Size(pKeys);
|
||||
for (Py_ssize_t i = 0; i < size; ++i) {
|
||||
PyObject* pKey = PyList_GetItem(pKeys, i);
|
||||
PyObject* pValue = PyDict_GetItem(pDict, pKey);
|
||||
Py_XDECREF(pKey);
|
||||
Py_XDECREF(pValue);
|
||||
}
|
||||
Py_XDECREF(pKeys);
|
||||
Py_XDECREF(pDict);
|
||||
}
|
||||
void PyWrap::setHyperparameters(const std::string& moduleName, const std::string& className, const json& hyperparameters)
|
||||
{
|
||||
PyObject* args = PyDict_New();
|
||||
|
||||
// Build dictionary of arguments with a little help of chatGPT
|
||||
std::cout << "Building dictionary of arguments" << std::endl;
|
||||
try {
|
||||
PyObject* pValue;
|
||||
for (const auto& [key, value] : hyperparameters.items()) {
|
||||
std::string type_name;
|
||||
if (value.type_name() == "string") {
|
||||
type_name = "s";
|
||||
pValue = Py_BuildValue("s", value.get<std::string>().c_str());
|
||||
std::cout << key << " s " << value.get<std::string>() << std::endl;
|
||||
} else {
|
||||
if (value.is_number_integer()) {
|
||||
pValue = Py_BuildValue("i", value.get<int>());
|
||||
std::cout << key << " i " << value.get<int>() << std::endl;
|
||||
} else {
|
||||
pValue = Py_BuildValue("f", value.get<double>());
|
||||
std::cout << key << " f " << value.get<double>() << std::endl;
|
||||
}
|
||||
}
|
||||
PyDict_SetItemString(args, key.c_str(), pValue);
|
||||
Py_XDECREF(pValue);
|
||||
}
|
||||
}
|
||||
catch (const std::exception& e) {
|
||||
|
||||
Py_DECREF(args);
|
||||
errorAbort(e.what());
|
||||
}
|
||||
std::cout << "PyDict_Size=" << PyDict_Size(args) << std::endl;
|
||||
std::cout << "Calling method set_args with" << std::endl;
|
||||
//printDictionary(args);
|
||||
Py_INCREF(args);
|
||||
PyObject* result;
|
||||
// Call the method with the argument dictionary with a little help of chatGPT
|
||||
auto instance = getClass(moduleName, className);
|
||||
try {
|
||||
if (!(result = PyObject_CallMethod(instance, "set_params", "O", args))) {
|
||||
//if (!(result = PyObject_Call(instance, PyObject_GetAttrString(instance, "set_params"), args, nullptr)))
|
||||
std::cout << "Cleaning up because of error" << std::endl;
|
||||
cleanDictionary(args);
|
||||
errorAbort("Couldn't call method set_args");
|
||||
}
|
||||
}
|
||||
catch (const std::exception& e) {
|
||||
std::cout << "Cleaning up because of exception" << std::endl;
|
||||
cleanDictionary(args);
|
||||
errorAbort(e.what());
|
||||
}
|
||||
std::cout << "Cleaning up everything went ok!" << std::endl;
|
||||
cleanDictionary(args);
|
||||
Py_XDECREF(result);
|
||||
}
|
||||
void PyWrap::fit(const std::string& moduleName, const std::string& className, CPyObject& X, CPyObject& y)
|
||||
{
|
||||
PyObject* instance = getClass(moduleName, className);
|
||||
@@ -123,10 +214,9 @@ namespace pywrap {
|
||||
errorAbort("Couldn't call method fit");
|
||||
}
|
||||
catch (const std::exception& e) {
|
||||
std::cerr << e.what() << '\n';
|
||||
RemoveInstance();
|
||||
exit(1);
|
||||
errorAbort(e.what());
|
||||
}
|
||||
Py_XDECREF(result);
|
||||
}
|
||||
|
||||
PyObject* PyWrap::predict(const std::string& moduleName, const std::string& className, CPyObject& X)
|
||||
@@ -139,9 +229,7 @@ namespace pywrap {
|
||||
errorAbort("Couldn't call method predict");
|
||||
}
|
||||
catch (const std::exception& e) {
|
||||
std::cerr << e.what() << '\n';
|
||||
RemoveInstance();
|
||||
exit(1);
|
||||
errorAbort(e.what());
|
||||
}
|
||||
Py_INCREF(result);
|
||||
return result; // Caller must free this object
|
||||
@@ -156,10 +244,10 @@ namespace pywrap {
|
||||
errorAbort("Couldn't call method score");
|
||||
}
|
||||
catch (const std::exception& e) {
|
||||
std::cerr << e.what() << '\n';
|
||||
RemoveInstance();
|
||||
exit(1);
|
||||
errorAbort(e.what());
|
||||
}
|
||||
return PyFloat_AsDouble(result);
|
||||
double resultValue = PyFloat_AsDouble(result);
|
||||
Py_XDECREF(result);
|
||||
return resultValue;
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user