PyWrap with built dictionary of arguments
This commit is contained in:
@@ -1,6 +1,5 @@
|
|||||||
#include "PyClassifier.h"
|
#include "PyClassifier.h"
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
||||||
namespace pywrap {
|
namespace pywrap {
|
||||||
namespace bp = boost::python;
|
namespace bp = boost::python;
|
||||||
namespace np = boost::python::numpy;
|
namespace np = boost::python::numpy;
|
||||||
@@ -38,7 +37,8 @@ namespace pywrap {
|
|||||||
PyClassifier& PyClassifier::fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states)
|
PyClassifier& PyClassifier::fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states)
|
||||||
{
|
{
|
||||||
if (!fitted && hyperparameters.size() > 0) {
|
if (!fitted && hyperparameters.size() > 0) {
|
||||||
std::cout << "Setting hyperparameters" << std::endl;
|
std::cout << "PyClassifier: Setting hyperparameters" << std::endl;
|
||||||
|
pyWrap->setHyperparameters(module, this->className, hyperparameters);
|
||||||
}
|
}
|
||||||
auto [Xn, yn] = tensors2numpy(X, y);
|
auto [Xn, yn] = tensors2numpy(X, y);
|
||||||
CPyObject Xp = bp::incref(bp::object(Xn).ptr());
|
CPyObject Xp = bp::incref(bp::object(Xn).ptr());
|
||||||
|
114
src/PyWrap.cc
114
src/PyWrap.cc
@@ -3,6 +3,7 @@
|
|||||||
#include "PyWrap.h"
|
#include "PyWrap.h"
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <map>
|
#include <map>
|
||||||
|
#include <iostream>
|
||||||
#include <boost/python/numpy.hpp>
|
#include <boost/python/numpy.hpp>
|
||||||
|
|
||||||
namespace pywrap {
|
namespace pywrap {
|
||||||
@@ -101,9 +102,7 @@ namespace pywrap {
|
|||||||
errorAbort("Couldn't call method " + method);
|
errorAbort("Couldn't call method " + method);
|
||||||
}
|
}
|
||||||
catch (const std::exception& e) {
|
catch (const std::exception& e) {
|
||||||
std::cerr << e.what() << '\n';
|
errorAbort(e.what());
|
||||||
RemoveInstance();
|
|
||||||
exit(1);
|
|
||||||
}
|
}
|
||||||
std::string value = PyUnicode_AsUTF8(result);
|
std::string value = PyUnicode_AsUTF8(result);
|
||||||
Py_XDECREF(result);
|
Py_XDECREF(result);
|
||||||
@@ -113,6 +112,98 @@ namespace pywrap {
|
|||||||
{
|
{
|
||||||
return callMethodString(moduleName, className, "version");
|
return callMethodString(moduleName, className, "version");
|
||||||
}
|
}
|
||||||
|
// void printPyObject(PyObject* obj)
|
||||||
|
// {
|
||||||
|
// PyObject* pStr = PyObject_Str(obj);
|
||||||
|
// const char* str = PyUnicode_AsUTF8(pStr);
|
||||||
|
// printf("%s\n", str);
|
||||||
|
// Py_XDECREF(pStr);
|
||||||
|
// }
|
||||||
|
|
||||||
|
// void printDictionary(PyObject* pDict)
|
||||||
|
// {
|
||||||
|
// PyObject* pKeys = PyDict_Keys(pDict);
|
||||||
|
// Py_ssize_t size = PyList_Size(pKeys);
|
||||||
|
|
||||||
|
// for (Py_ssize_t i = 0; i < size; ++i) {
|
||||||
|
// PyObject* pKey = PyList_GetItem(pKeys, i);
|
||||||
|
// PyObject* pValue = PyDict_GetItem(pDict, pKey);
|
||||||
|
|
||||||
|
// printf("%s: ", PyUnicode_AsUTF8(pKey));
|
||||||
|
// printPyObject(pValue);
|
||||||
|
// }
|
||||||
|
|
||||||
|
// Py_XDECREF(pKeys);
|
||||||
|
// }
|
||||||
|
void cleanDictionary(PyObject* pDict)
|
||||||
|
{
|
||||||
|
PyObject* pKeys = PyDict_Keys(pDict);
|
||||||
|
Py_ssize_t size = PyList_Size(pKeys);
|
||||||
|
for (Py_ssize_t i = 0; i < size; ++i) {
|
||||||
|
PyObject* pKey = PyList_GetItem(pKeys, i);
|
||||||
|
PyObject* pValue = PyDict_GetItem(pDict, pKey);
|
||||||
|
Py_XDECREF(pKey);
|
||||||
|
Py_XDECREF(pValue);
|
||||||
|
}
|
||||||
|
Py_XDECREF(pKeys);
|
||||||
|
Py_XDECREF(pDict);
|
||||||
|
}
|
||||||
|
void PyWrap::setHyperparameters(const std::string& moduleName, const std::string& className, const json& hyperparameters)
|
||||||
|
{
|
||||||
|
PyObject* args = PyDict_New();
|
||||||
|
|
||||||
|
// Build dictionary of arguments with a little help of chatGPT
|
||||||
|
std::cout << "Building dictionary of arguments" << std::endl;
|
||||||
|
try {
|
||||||
|
PyObject* pValue;
|
||||||
|
for (const auto& [key, value] : hyperparameters.items()) {
|
||||||
|
std::string type_name;
|
||||||
|
if (value.type_name() == "string") {
|
||||||
|
type_name = "s";
|
||||||
|
pValue = Py_BuildValue("s", value.get<std::string>().c_str());
|
||||||
|
std::cout << key << " s " << value.get<std::string>() << std::endl;
|
||||||
|
} else {
|
||||||
|
if (value.is_number_integer()) {
|
||||||
|
pValue = Py_BuildValue("i", value.get<int>());
|
||||||
|
std::cout << key << " i " << value.get<int>() << std::endl;
|
||||||
|
} else {
|
||||||
|
pValue = Py_BuildValue("f", value.get<double>());
|
||||||
|
std::cout << key << " f " << value.get<double>() << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
PyDict_SetItemString(args, key.c_str(), pValue);
|
||||||
|
Py_XDECREF(pValue);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
catch (const std::exception& e) {
|
||||||
|
|
||||||
|
Py_DECREF(args);
|
||||||
|
errorAbort(e.what());
|
||||||
|
}
|
||||||
|
std::cout << "PyDict_Size=" << PyDict_Size(args) << std::endl;
|
||||||
|
std::cout << "Calling method set_args with" << std::endl;
|
||||||
|
//printDictionary(args);
|
||||||
|
Py_INCREF(args);
|
||||||
|
PyObject* result;
|
||||||
|
// Call the method with the argument dictionary with a little help of chatGPT
|
||||||
|
auto instance = getClass(moduleName, className);
|
||||||
|
try {
|
||||||
|
if (!(result = PyObject_CallMethod(instance, "set_params", "O", args))) {
|
||||||
|
//if (!(result = PyObject_Call(instance, PyObject_GetAttrString(instance, "set_params"), args, nullptr)))
|
||||||
|
std::cout << "Cleaning up because of error" << std::endl;
|
||||||
|
cleanDictionary(args);
|
||||||
|
errorAbort("Couldn't call method set_args");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
catch (const std::exception& e) {
|
||||||
|
std::cout << "Cleaning up because of exception" << std::endl;
|
||||||
|
cleanDictionary(args);
|
||||||
|
errorAbort(e.what());
|
||||||
|
}
|
||||||
|
std::cout << "Cleaning up everything went ok!" << std::endl;
|
||||||
|
cleanDictionary(args);
|
||||||
|
Py_XDECREF(result);
|
||||||
|
}
|
||||||
void PyWrap::fit(const std::string& moduleName, const std::string& className, CPyObject& X, CPyObject& y)
|
void PyWrap::fit(const std::string& moduleName, const std::string& className, CPyObject& X, CPyObject& y)
|
||||||
{
|
{
|
||||||
PyObject* instance = getClass(moduleName, className);
|
PyObject* instance = getClass(moduleName, className);
|
||||||
@@ -123,10 +214,9 @@ namespace pywrap {
|
|||||||
errorAbort("Couldn't call method fit");
|
errorAbort("Couldn't call method fit");
|
||||||
}
|
}
|
||||||
catch (const std::exception& e) {
|
catch (const std::exception& e) {
|
||||||
std::cerr << e.what() << '\n';
|
errorAbort(e.what());
|
||||||
RemoveInstance();
|
|
||||||
exit(1);
|
|
||||||
}
|
}
|
||||||
|
Py_XDECREF(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
PyObject* PyWrap::predict(const std::string& moduleName, const std::string& className, CPyObject& X)
|
PyObject* PyWrap::predict(const std::string& moduleName, const std::string& className, CPyObject& X)
|
||||||
@@ -139,9 +229,7 @@ namespace pywrap {
|
|||||||
errorAbort("Couldn't call method predict");
|
errorAbort("Couldn't call method predict");
|
||||||
}
|
}
|
||||||
catch (const std::exception& e) {
|
catch (const std::exception& e) {
|
||||||
std::cerr << e.what() << '\n';
|
errorAbort(e.what());
|
||||||
RemoveInstance();
|
|
||||||
exit(1);
|
|
||||||
}
|
}
|
||||||
Py_INCREF(result);
|
Py_INCREF(result);
|
||||||
return result; // Caller must free this object
|
return result; // Caller must free this object
|
||||||
@@ -156,10 +244,10 @@ namespace pywrap {
|
|||||||
errorAbort("Couldn't call method score");
|
errorAbort("Couldn't call method score");
|
||||||
}
|
}
|
||||||
catch (const std::exception& e) {
|
catch (const std::exception& e) {
|
||||||
std::cerr << e.what() << '\n';
|
errorAbort(e.what());
|
||||||
RemoveInstance();
|
|
||||||
exit(1);
|
|
||||||
}
|
}
|
||||||
return PyFloat_AsDouble(result);
|
double resultValue = PyFloat_AsDouble(result);
|
||||||
|
Py_XDECREF(result);
|
||||||
|
return resultValue;
|
||||||
}
|
}
|
||||||
}
|
}
|
@@ -5,6 +5,7 @@
|
|||||||
#include <map>
|
#include <map>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
|
#include <nlohmann/json.hpp>
|
||||||
#include "PyHelper.hpp"
|
#include "PyHelper.hpp"
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
@@ -13,6 +14,7 @@ namespace pywrap {
|
|||||||
/*
|
/*
|
||||||
Singleton class to handle Python/numpy interpreter.
|
Singleton class to handle Python/numpy interpreter.
|
||||||
*/
|
*/
|
||||||
|
using json = nlohmann::json;
|
||||||
class PyWrap {
|
class PyWrap {
|
||||||
public:
|
public:
|
||||||
PyWrap() = default;
|
PyWrap() = default;
|
||||||
@@ -22,6 +24,7 @@ namespace pywrap {
|
|||||||
~PyWrap() = default;
|
~PyWrap() = default;
|
||||||
std::string callMethodString(const std::string& moduleName, const std::string& className, const std::string& method);
|
std::string callMethodString(const std::string& moduleName, const std::string& className, const std::string& method);
|
||||||
std::string version(const std::string& moduleName, const std::string& className);
|
std::string version(const std::string& moduleName, const std::string& className);
|
||||||
|
void setHyperparameters(const std::string& moduleName, const std::string& className, const json& hyperparameters);
|
||||||
void fit(const std::string& moduleName, const std::string& className, CPyObject& X, CPyObject& y);
|
void fit(const std::string& moduleName, const std::string& className, CPyObject& X, CPyObject& y);
|
||||||
PyObject* predict(const std::string& moduleName, const std::string& className, CPyObject& X);
|
PyObject* predict(const std::string& moduleName, const std::string& className, CPyObject& X);
|
||||||
double score(const std::string& moduleName, const std::string& className, CPyObject& X, CPyObject& y);
|
double score(const std::string& moduleName, const std::string& className, CPyObject& X, CPyObject& y);
|
||||||
|
@@ -5,4 +5,11 @@ namespace pywrap {
|
|||||||
{
|
{
|
||||||
return callMethodString("1.0");
|
return callMethodString("1.0");
|
||||||
}
|
}
|
||||||
|
void SVC::setHyperparameters(const nlohmann::json& hyperparameters)
|
||||||
|
{
|
||||||
|
// Check if hyperparameters are valid
|
||||||
|
const std::vector<std::string> validKeys = { "C", "gamma", "kernel", "random_state" };
|
||||||
|
checkHyperparameters(validKeys, hyperparameters);
|
||||||
|
this->hyperparameters = hyperparameters;
|
||||||
|
}
|
||||||
} /* namespace pywrap */
|
} /* namespace pywrap */
|
@@ -8,6 +8,7 @@ namespace pywrap {
|
|||||||
SVC() : PyClassifier("sklearn.svm", "SVC") {};
|
SVC() : PyClassifier("sklearn.svm", "SVC") {};
|
||||||
~SVC() = default;
|
~SVC() = default;
|
||||||
std::string version();
|
std::string version();
|
||||||
|
void setHyperparameters(const nlohmann::json& hyperparameters) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
} /* namespace pywrap */
|
} /* namespace pywrap */
|
||||||
|
@@ -43,6 +43,7 @@ tuple<Tensor, Tensor, vector<string>, string, map<string, vector<int>>> loadData
|
|||||||
|
|
||||||
int main(int argc, char* argv[])
|
int main(int argc, char* argv[])
|
||||||
{
|
{
|
||||||
|
using json = nlohmann::json;
|
||||||
cout << "* Begin." << endl;
|
cout << "* Begin." << endl;
|
||||||
{
|
{
|
||||||
auto datasetName = "iris";
|
auto datasetName = "iris";
|
||||||
@@ -52,10 +53,13 @@ int main(int argc, char* argv[])
|
|||||||
cout << "X: " << X.sizes() << endl;
|
cout << "X: " << X.sizes() << endl;
|
||||||
cout << "y: " << y.sizes() << endl;
|
cout << "y: " << y.sizes() << endl;
|
||||||
auto clf = pywrap::STree();
|
auto clf = pywrap::STree();
|
||||||
auto hyperparameters = nlohmann::json({ "max_depth": 3, "C" : 0.7 });
|
auto hyperparameters = json::parse("{\"C\": 0.7, \"max_iter\": 1e4, \"kernel\": \"rbf\"}");
|
||||||
clf.setHyperparameters(hyperparameters);
|
//clf.setHyperparameters(hyperparameters);
|
||||||
cout << "STree Version: " << clf.version() << endl;
|
cout << "STree Version: " << clf.version() << endl;
|
||||||
auto svc = pywrap::SVC();
|
auto svc = pywrap::SVC();
|
||||||
|
cout << "SVC with hyperparameters" << endl;
|
||||||
|
hyperparameters = json::parse("{\"kernel\": \"rbf\", \"C\": 0.7, \"random_state\": 17}");
|
||||||
|
svc.setHyperparameters(hyperparameters);
|
||||||
svc.fit(X, y, features, className, states);
|
svc.fit(X, y, features, className, states);
|
||||||
cout << "Graph: " << endl << clf.graph() << endl;
|
cout << "Graph: " << endl << clf.graph() << endl;
|
||||||
clf.fit(X, y, features, className, states);
|
clf.fit(X, y, features, className, states);
|
||||||
|
Reference in New Issue
Block a user