Add id to classifier

This commit is contained in:
2023-11-10 11:11:10 +01:00
parent 384b9071a2
commit 55286168cb
6 changed files with 56 additions and 46 deletions

View File

@@ -36,9 +36,9 @@ namespace pywrap {
wrapper = nullptr;
}
}
void PyWrap::importClass(const std::string& moduleName, const std::string& className)
void PyWrap::importClass(const clfId_t id, const std::string& moduleName, const std::string& className)
{
auto result = moduleClassMap.find({ moduleName, className });
auto result = moduleClassMap.find(id);
if (result != moduleClassMap.end()) {
return;
}
@@ -58,12 +58,13 @@ namespace pywrap {
module.AddRef();
classObject.AddRef();
instance.AddRef();
moduleClassMap.insert({ { moduleName, className }, { module.getObject(), classObject.getObject(), instance.getObject() } });
moduleClassMap.insert({ id, { module.getObject(), classObject.getObject(), instance.getObject() } });
}
void PyWrap::clean(const std::string& moduleName, const std::string& className)
void PyWrap::clean(const clfId_t id)
{
// Remove Python interpreter if no more modules imported left
std::lock_guard<std::mutex> lock(mutex);
auto result = moduleClassMap.find({ moduleName, className });
auto result = moduleClassMap.find(id);
if (result == moduleClassMap.end()) {
return;
}
@@ -73,7 +74,7 @@ namespace pywrap {
moduleClassMap.erase(result);
if (PyErr_Occurred()) {
PyErr_Print();
errorAbort("Error cleaning module " + moduleName + " and class " + className);
errorAbort("Error cleaning module ");
}
if (moduleClassMap.empty()) {
RemoveInstance();
@@ -86,17 +87,17 @@ namespace pywrap {
RemoveInstance();
exit(1);
}
PyObject* PyWrap::getClass(const std::string& moduleName, const std::string& className)
PyObject* PyWrap::getClass(const clfId_t id)
{
auto item = moduleClassMap.find({ moduleName, className });
auto item = moduleClassMap.find(id);
if (item == moduleClassMap.end()) {
errorAbort("Module " + moduleName + " and class " + className + " not found");
errorAbort("Module not found");
}
return std::get<2>(item->second);
}
std::string PyWrap::callMethodString(const std::string& moduleName, const std::string& className, const std::string& method)
std::string PyWrap::callMethodString(const clfId_t id, const std::string& method)
{
PyObject* instance = getClass(moduleName, className);
PyObject* instance = getClass(id);
PyObject* result;
try {
if (!(result = PyObject_CallMethod(instance, method.c_str(), NULL)))
@@ -109,16 +110,16 @@ namespace pywrap {
Py_XDECREF(result);
return value;
}
std::string PyWrap::version(const std::string& moduleName, const std::string& className)
std::string PyWrap::version(const clfId_t id)
{
return callMethodString(moduleName, className, "version");
return callMethodString(id, "version");
}
void PyWrap::setHyperparameters(const std::string& moduleName, const std::string& className, const json& hyperparameters)
void PyWrap::setHyperparameters(const clfId_t id, const json& hyperparameters)
{
// Set hyperparameters as attributes of the class
std::cout << "Building dictionary of arguments" << std::endl;
PyObject* pValue;
PyObject* instance = getClass(moduleName, className);
PyObject* instance = getClass(id);
for (const auto& [key, value] : hyperparameters.items()) {
std::stringstream oss;
oss << value.type_name();
@@ -139,9 +140,9 @@ namespace pywrap {
Py_XDECREF(pValue);
}
}
void PyWrap::fit(const std::string& moduleName, const std::string& className, CPyObject& X, CPyObject& y)
void PyWrap::fit(const clfId_t id, CPyObject& X, CPyObject& y)
{
PyObject* instance = getClass(moduleName, className);
PyObject* instance = getClass(id);
CPyObject result;
std::string method = "fit";
try {
@@ -154,9 +155,9 @@ namespace pywrap {
// Py_XDECREF(result);
}
PyObject* PyWrap::predict(const std::string& moduleName, const std::string& className, CPyObject& X)
PyObject* PyWrap::predict(const clfId_t id, CPyObject& X)
{
PyObject* instance = getClass(moduleName, className);
PyObject* instance = getClass(id);
PyObject* result;
std::string method = "predict";
try {
@@ -169,9 +170,9 @@ namespace pywrap {
Py_INCREF(result);
return result; // Caller must free this object
}
double PyWrap::score(const std::string& moduleName, const std::string& className, CPyObject& X, CPyObject& y)
double PyWrap::score(const clfId_t id, CPyObject& X, CPyObject& y)
{
PyObject* instance = getClass(moduleName, className);
PyObject* instance = getClass(id);
CPyObject result;
std::string method = "score";
try {