Set attributes

This commit is contained in:
2023-11-10 10:42:01 +01:00
parent 67e4b0af47
commit 8334f81276
4 changed files with 26 additions and 88 deletions

View File

@@ -5,6 +5,7 @@ namespace pywrap {
namespace np = boost::python::numpy; namespace np = boost::python::numpy;
PyClassifier::PyClassifier(const std::string& module, const std::string& className) : module(module), className(className), fitted(false) PyClassifier::PyClassifier(const std::string& module, const std::string& className) : module(module), className(className), fitted(false)
{ {
id = reinterpret_cast<uint32_t>(&this);
pyWrap = PyWrap::GetInstance(); pyWrap = PyWrap::GetInstance();
pyWrap->importClass(module, className); pyWrap->importClass(module, className);
} }

View File

@@ -29,6 +29,7 @@ namespace pywrap {
PyWrap* pyWrap; PyWrap* pyWrap;
std::string module; std::string module;
std::string className; std::string className;
uint32_t id;
bool fitted; bool fitted;
}; };
} /* namespace pywrap */ } /* namespace pywrap */

View File

@@ -4,6 +4,7 @@
#include <string> #include <string>
#include <map> #include <map>
#include <iostream> #include <iostream>
#include <sstream>
#include <boost/python/numpy.hpp> #include <boost/python/numpy.hpp>
namespace pywrap { namespace pywrap {
@@ -112,97 +113,31 @@ namespace pywrap {
{ {
return callMethodString(moduleName, className, "version"); return callMethodString(moduleName, className, "version");
} }
// void printPyObject(PyObject* obj)
// {
// PyObject* pStr = PyObject_Str(obj);
// const char* str = PyUnicode_AsUTF8(pStr);
// printf("%s\n", str);
// Py_XDECREF(pStr);
// }
// void printDictionary(PyObject* pDict)
// {
// PyObject* pKeys = PyDict_Keys(pDict);
// Py_ssize_t size = PyList_Size(pKeys);
// for (Py_ssize_t i = 0; i < size; ++i) {
// PyObject* pKey = PyList_GetItem(pKeys, i);
// PyObject* pValue = PyDict_GetItem(pDict, pKey);
// printf("%s: ", PyUnicode_AsUTF8(pKey));
// printPyObject(pValue);
// }
// Py_XDECREF(pKeys);
// }
void cleanDictionary(PyObject* pDict)
{
PyObject* pKeys = PyDict_Keys(pDict);
Py_ssize_t size = PyList_Size(pKeys);
for (Py_ssize_t i = 0; i < size; ++i) {
PyObject* pKey = PyList_GetItem(pKeys, i);
PyObject* pValue = PyDict_GetItem(pDict, pKey);
Py_XDECREF(pKey);
Py_XDECREF(pValue);
}
Py_XDECREF(pKeys);
Py_XDECREF(pDict);
}
void PyWrap::setHyperparameters(const std::string& moduleName, const std::string& className, const json& hyperparameters) void PyWrap::setHyperparameters(const std::string& moduleName, const std::string& className, const json& hyperparameters)
{ {
PyObject* args = PyDict_New(); // Set hyperparameters as attributes of the class
// Build dictionary of arguments with a little help of chatGPT
std::cout << "Building dictionary of arguments" << std::endl; std::cout << "Building dictionary of arguments" << std::endl;
try { PyObject* pValue;
PyObject* pValue; PyObject* instance = getClass(moduleName, className);
for (const auto& [key, value] : hyperparameters.items()) { for (const auto& [key, value] : hyperparameters.items()) {
std::string type_name; std::stringstream oss;
if (value.type_name() == "string") { oss << value.type_name();
type_name = "s"; if (oss.str() == "string") {
pValue = Py_BuildValue("s", value.get<std::string>().c_str()); pValue = Py_BuildValue("s", value.get<std::string>().c_str());
std::cout << key << " s " << value.get<std::string>() << std::endl; } else {
if (value.is_number_integer()) {
pValue = Py_BuildValue("i", value.get<int>());
} else { } else {
if (value.is_number_integer()) { pValue = Py_BuildValue("f", value.get<double>());
pValue = Py_BuildValue("i", value.get<int>());
std::cout << key << " i " << value.get<int>() << std::endl;
} else {
pValue = Py_BuildValue("f", value.get<double>());
std::cout << key << " f " << value.get<double>() << std::endl;
}
} }
PyDict_SetItemString(args, key.c_str(), pValue);
Py_XDECREF(pValue);
} }
} int res = PyObject_SetAttrString(instance, key.c_str(), pValue);
catch (const std::exception& e) { if (res == -1 && PyErr_Occurred()) {
Py_DECREF(args);
errorAbort(e.what());
}
std::cout << "PyDict_Size=" << PyDict_Size(args) << std::endl;
std::cout << "Calling method set_args with" << std::endl;
//printDictionary(args);
Py_INCREF(args);
PyObject* result;
// Call the method with the argument dictionary with a little help of chatGPT
auto instance = getClass(moduleName, className);
try {
if (!(result = PyObject_CallMethod(instance, "set_params", "O", args))) {
//if (!(result = PyObject_Call(instance, PyObject_GetAttrString(instance, "set_params"), args, nullptr)))
std::cout << "Cleaning up because of error" << std::endl;
cleanDictionary(args); cleanDictionary(args);
errorAbort("Couldn't call method set_args"); errorAbort("Couldn't set attribute " + key + "=" + value.dump());
} }
Py_XDECREF(pValue);
} }
catch (const std::exception& e) {
std::cout << "Cleaning up because of exception" << std::endl;
cleanDictionary(args);
errorAbort(e.what());
}
std::cout << "Cleaning up everything went ok!" << std::endl;
cleanDictionary(args);
Py_XDECREF(result);
} }
void PyWrap::fit(const std::string& moduleName, const std::string& className, CPyObject& X, CPyObject& y) void PyWrap::fit(const std::string& moduleName, const std::string& className, CPyObject& X, CPyObject& y)
{ {
@@ -216,7 +151,7 @@ namespace pywrap {
catch (const std::exception& e) { catch (const std::exception& e) {
errorAbort(e.what()); errorAbort(e.what());
} }
Py_XDECREF(result); // Py_XDECREF(result);
} }
PyObject* PyWrap::predict(const std::string& moduleName, const std::string& className, CPyObject& X) PyObject* PyWrap::predict(const std::string& moduleName, const std::string& className, CPyObject& X)
@@ -247,7 +182,7 @@ namespace pywrap {
errorAbort(e.what()); errorAbort(e.what());
} }
double resultValue = PyFloat_AsDouble(result); double resultValue = PyFloat_AsDouble(result);
Py_XDECREF(result); // Py_XDECREF(result);
return resultValue; return resultValue;
} }
} }

View File

@@ -53,16 +53,16 @@ int main(int argc, char* argv[])
cout << "X: " << X.sizes() << endl; cout << "X: " << X.sizes() << endl;
cout << "y: " << y.sizes() << endl; cout << "y: " << y.sizes() << endl;
auto clf = pywrap::STree(); auto clf = pywrap::STree();
auto hyperparameters = json::parse("{\"C\": 0.7, \"max_iter\": 1e4, \"kernel\": \"rbf\"}"); // auto stree = pywrap::STree();
//clf.setHyperparameters(hyperparameters); auto hyperparameters = json::parse("{\"C\": 0.7, \"max_iter\": 10000, \"kernel\": \"rbf\", \"random_state\": 17}");
// stree.setHyperparameters(hyperparameters);
cout << "STree Version: " << clf.version() << endl; cout << "STree Version: " << clf.version() << endl;
auto svc = pywrap::SVC(); auto svc = pywrap::SVC();
cout << "SVC with hyperparameters" << endl; cout << "SVC with hyperparameters" << endl;
hyperparameters = json::parse("{\"kernel\": \"rbf\", \"C\": 0.7, \"random_state\": 17}");
svc.setHyperparameters(hyperparameters);
svc.fit(X, y, features, className, states); svc.fit(X, y, features, className, states);
cout << "Graph: " << endl << clf.graph() << endl; cout << "Graph: " << endl << clf.graph() << endl;
clf.fit(X, y, features, className, states); clf.fit(X, y, features, className, states);
// stree.fit(X, y, features, className, states);
auto prediction = clf.predict(X); auto prediction = clf.predict(X);
cout << "Prediction: " << endl << "{"; cout << "Prediction: " << endl << "{";
for (int i = 0; i < prediction.size(0); ++i) { for (int i = 0; i < prediction.size(0); ++i) {
@@ -74,6 +74,7 @@ int main(int argc, char* argv[])
auto xg = pywrap::RandomForest(); auto xg = pywrap::RandomForest();
xg.fit(X, y, features, className, states); xg.fit(X, y, features, className, states);
cout << "STree Score ......: " << clf.score(X, y) << endl; cout << "STree Score ......: " << clf.score(X, y) << endl;
// cout << "STree hyper score : " << stree.score(X, y) << endl;
cout << "RandomForest Score: " << rf.score(X, y) << endl; cout << "RandomForest Score: " << rf.score(X, y) << endl;
cout << "SVC Score ........: " << svc.score(X, y) << endl; cout << "SVC Score ........: " << svc.score(X, y) << endl;
cout << "XGBoost Score ....: " << xg.score(X, y) << endl; cout << "XGBoost Score ....: " << xg.score(X, y) << endl;