Add id to classifier
This commit is contained in:
@@ -36,9 +36,9 @@ namespace pywrap {
|
||||
wrapper = nullptr;
|
||||
}
|
||||
}
|
||||
void PyWrap::importClass(const std::string& moduleName, const std::string& className)
|
||||
void PyWrap::importClass(const clfId_t id, const std::string& moduleName, const std::string& className)
|
||||
{
|
||||
auto result = moduleClassMap.find({ moduleName, className });
|
||||
auto result = moduleClassMap.find(id);
|
||||
if (result != moduleClassMap.end()) {
|
||||
return;
|
||||
}
|
||||
@@ -58,12 +58,13 @@ namespace pywrap {
|
||||
module.AddRef();
|
||||
classObject.AddRef();
|
||||
instance.AddRef();
|
||||
moduleClassMap.insert({ { moduleName, className }, { module.getObject(), classObject.getObject(), instance.getObject() } });
|
||||
moduleClassMap.insert({ id, { module.getObject(), classObject.getObject(), instance.getObject() } });
|
||||
}
|
||||
void PyWrap::clean(const std::string& moduleName, const std::string& className)
|
||||
void PyWrap::clean(const clfId_t id)
|
||||
{
|
||||
// Remove Python interpreter if no more modules imported left
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
auto result = moduleClassMap.find({ moduleName, className });
|
||||
auto result = moduleClassMap.find(id);
|
||||
if (result == moduleClassMap.end()) {
|
||||
return;
|
||||
}
|
||||
@@ -73,7 +74,7 @@ namespace pywrap {
|
||||
moduleClassMap.erase(result);
|
||||
if (PyErr_Occurred()) {
|
||||
PyErr_Print();
|
||||
errorAbort("Error cleaning module " + moduleName + " and class " + className);
|
||||
errorAbort("Error cleaning module ");
|
||||
}
|
||||
if (moduleClassMap.empty()) {
|
||||
RemoveInstance();
|
||||
@@ -86,17 +87,17 @@ namespace pywrap {
|
||||
RemoveInstance();
|
||||
exit(1);
|
||||
}
|
||||
PyObject* PyWrap::getClass(const std::string& moduleName, const std::string& className)
|
||||
PyObject* PyWrap::getClass(const clfId_t id)
|
||||
{
|
||||
auto item = moduleClassMap.find({ moduleName, className });
|
||||
auto item = moduleClassMap.find(id);
|
||||
if (item == moduleClassMap.end()) {
|
||||
errorAbort("Module " + moduleName + " and class " + className + " not found");
|
||||
errorAbort("Module not found");
|
||||
}
|
||||
return std::get<2>(item->second);
|
||||
}
|
||||
std::string PyWrap::callMethodString(const std::string& moduleName, const std::string& className, const std::string& method)
|
||||
std::string PyWrap::callMethodString(const clfId_t id, const std::string& method)
|
||||
{
|
||||
PyObject* instance = getClass(moduleName, className);
|
||||
PyObject* instance = getClass(id);
|
||||
PyObject* result;
|
||||
try {
|
||||
if (!(result = PyObject_CallMethod(instance, method.c_str(), NULL)))
|
||||
@@ -109,16 +110,16 @@ namespace pywrap {
|
||||
Py_XDECREF(result);
|
||||
return value;
|
||||
}
|
||||
std::string PyWrap::version(const std::string& moduleName, const std::string& className)
|
||||
std::string PyWrap::version(const clfId_t id)
|
||||
{
|
||||
return callMethodString(moduleName, className, "version");
|
||||
return callMethodString(id, "version");
|
||||
}
|
||||
void PyWrap::setHyperparameters(const std::string& moduleName, const std::string& className, const json& hyperparameters)
|
||||
void PyWrap::setHyperparameters(const clfId_t id, const json& hyperparameters)
|
||||
{
|
||||
// Set hyperparameters as attributes of the class
|
||||
std::cout << "Building dictionary of arguments" << std::endl;
|
||||
PyObject* pValue;
|
||||
PyObject* instance = getClass(moduleName, className);
|
||||
PyObject* instance = getClass(id);
|
||||
for (const auto& [key, value] : hyperparameters.items()) {
|
||||
std::stringstream oss;
|
||||
oss << value.type_name();
|
||||
@@ -139,9 +140,9 @@ namespace pywrap {
|
||||
Py_XDECREF(pValue);
|
||||
}
|
||||
}
|
||||
void PyWrap::fit(const std::string& moduleName, const std::string& className, CPyObject& X, CPyObject& y)
|
||||
void PyWrap::fit(const clfId_t id, CPyObject& X, CPyObject& y)
|
||||
{
|
||||
PyObject* instance = getClass(moduleName, className);
|
||||
PyObject* instance = getClass(id);
|
||||
CPyObject result;
|
||||
std::string method = "fit";
|
||||
try {
|
||||
@@ -154,9 +155,9 @@ namespace pywrap {
|
||||
// Py_XDECREF(result);
|
||||
}
|
||||
|
||||
PyObject* PyWrap::predict(const std::string& moduleName, const std::string& className, CPyObject& X)
|
||||
PyObject* PyWrap::predict(const clfId_t id, CPyObject& X)
|
||||
{
|
||||
PyObject* instance = getClass(moduleName, className);
|
||||
PyObject* instance = getClass(id);
|
||||
PyObject* result;
|
||||
std::string method = "predict";
|
||||
try {
|
||||
@@ -169,9 +170,9 @@ namespace pywrap {
|
||||
Py_INCREF(result);
|
||||
return result; // Caller must free this object
|
||||
}
|
||||
double PyWrap::score(const std::string& moduleName, const std::string& className, CPyObject& X, CPyObject& y)
|
||||
double PyWrap::score(const clfId_t id, CPyObject& X, CPyObject& y)
|
||||
{
|
||||
PyObject* instance = getClass(moduleName, className);
|
||||
PyObject* instance = getClass(id);
|
||||
CPyObject result;
|
||||
std::string method = "score";
|
||||
try {
|
||||
|
Reference in New Issue
Block a user