Add id to classifier

This commit is contained in:
2023-11-10 11:11:10 +01:00
parent 384b9071a2
commit 55286168cb
6 changed files with 56 additions and 46 deletions

View File

@@ -5,13 +5,15 @@ namespace pywrap {
namespace np = boost::python::numpy; namespace np = boost::python::numpy;
PyClassifier::PyClassifier(const std::string& module, const std::string& className) : module(module), className(className), fitted(false) PyClassifier::PyClassifier(const std::string& module, const std::string& className) : module(module), className(className), fitted(false)
{ {
id = 0;//reinterpret_cast<uint32_t>(&this); // This id allows to have more than one instance of the same module/class
id = reinterpret_cast<clfId_t>(this);
std::cout << "PyClassifier: Creating instance of " << module << " and class " << className << " id " << id << std::endl;
pyWrap = PyWrap::GetInstance(); pyWrap = PyWrap::GetInstance();
pyWrap->importClass(module, className); pyWrap->importClass(id, module, className);
} }
PyClassifier::~PyClassifier() PyClassifier::~PyClassifier()
{ {
pyWrap->clean(module, className); pyWrap->clean(id);
} }
np::ndarray tensor2numpy(torch::Tensor& X) np::ndarray tensor2numpy(torch::Tensor& X)
{ {
@@ -29,22 +31,22 @@ namespace pywrap {
} }
std::string PyClassifier::version() std::string PyClassifier::version()
{ {
return pyWrap->version(module, className); return pyWrap->version(id);
} }
std::string PyClassifier::callMethodString(const std::string& method) std::string PyClassifier::callMethodString(const std::string& method)
{ {
return pyWrap->callMethodString(module, className, method); return pyWrap->callMethodString(id, method);
} }
PyClassifier& PyClassifier::fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) PyClassifier& PyClassifier::fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states)
{ {
if (!fitted && hyperparameters.size() > 0) { if (!fitted && hyperparameters.size() > 0) {
std::cout << "PyClassifier: Setting hyperparameters" << std::endl; std::cout << "PyClassifier: Setting hyperparameters" << std::endl;
pyWrap->setHyperparameters(module, this->className, hyperparameters); pyWrap->setHyperparameters(id, hyperparameters);
} }
auto [Xn, yn] = tensors2numpy(X, y); auto [Xn, yn] = tensors2numpy(X, y);
CPyObject Xp = bp::incref(bp::object(Xn).ptr()); CPyObject Xp = bp::incref(bp::object(Xn).ptr());
CPyObject yp = bp::incref(bp::object(yn).ptr()); CPyObject yp = bp::incref(bp::object(yn).ptr());
pyWrap->fit(module, this->className, Xp, yp); pyWrap->fit(id, Xp, yp);
fitted = true; fitted = true;
return *this; return *this;
} }
@@ -53,7 +55,7 @@ namespace pywrap {
int dimension = X.size(1); int dimension = X.size(1);
auto Xn = tensor2numpy(X); auto Xn = tensor2numpy(X);
CPyObject Xp = bp::incref(bp::object(Xn).ptr()); CPyObject Xp = bp::incref(bp::object(Xn).ptr());
PyObject* incoming = pyWrap->predict(module, className, Xp); PyObject* incoming = pyWrap->predict(id, Xp);
bp::handle<> handle(incoming); bp::handle<> handle(incoming);
bp::object object(handle); bp::object object(handle);
np::ndarray prediction = np::from_object(object); np::ndarray prediction = np::from_object(object);
@@ -72,7 +74,7 @@ namespace pywrap {
auto [Xn, yn] = tensors2numpy(X, y); auto [Xn, yn] = tensors2numpy(X, y);
CPyObject Xp = bp::incref(bp::object(Xn).ptr()); CPyObject Xp = bp::incref(bp::object(Xn).ptr());
CPyObject yp = bp::incref(bp::object(yn).ptr()); CPyObject yp = bp::incref(bp::object(yn).ptr());
auto result = pyWrap->score(module, className, Xp, yp); auto result = pyWrap->score(id, Xp, yp);
return result; return result;
} }
void PyClassifier::setHyperparameters(const nlohmann::json& hyperparameters) void PyClassifier::setHyperparameters(const nlohmann::json& hyperparameters)

View File

@@ -10,6 +10,7 @@
#include <torch/torch.h> #include <torch/torch.h>
#include "PyWrap.h" #include "PyWrap.h"
#include "Classifier.h" #include "Classifier.h"
#include "TypeId.h"
namespace pywrap { namespace pywrap {
class PyClassifier : public Classifier { class PyClassifier : public Classifier {
@@ -29,7 +30,7 @@ namespace pywrap {
PyWrap* pyWrap; PyWrap* pyWrap;
std::string module; std::string module;
std::string className; std::string className;
uint32_t id; clfId_t id;
bool fitted; bool fitted;
}; };
} /* namespace pywrap */ } /* namespace pywrap */

View File

@@ -36,9 +36,9 @@ namespace pywrap {
wrapper = nullptr; wrapper = nullptr;
} }
} }
void PyWrap::importClass(const std::string& moduleName, const std::string& className) void PyWrap::importClass(const clfId_t id, const std::string& moduleName, const std::string& className)
{ {
auto result = moduleClassMap.find({ moduleName, className }); auto result = moduleClassMap.find(id);
if (result != moduleClassMap.end()) { if (result != moduleClassMap.end()) {
return; return;
} }
@@ -58,12 +58,13 @@ namespace pywrap {
module.AddRef(); module.AddRef();
classObject.AddRef(); classObject.AddRef();
instance.AddRef(); instance.AddRef();
moduleClassMap.insert({ { moduleName, className }, { module.getObject(), classObject.getObject(), instance.getObject() } }); moduleClassMap.insert({ id, { module.getObject(), classObject.getObject(), instance.getObject() } });
} }
void PyWrap::clean(const std::string& moduleName, const std::string& className) void PyWrap::clean(const clfId_t id)
{ {
// Remove Python interpreter if no more modules imported left
std::lock_guard<std::mutex> lock(mutex); std::lock_guard<std::mutex> lock(mutex);
auto result = moduleClassMap.find({ moduleName, className }); auto result = moduleClassMap.find(id);
if (result == moduleClassMap.end()) { if (result == moduleClassMap.end()) {
return; return;
} }
@@ -73,7 +74,7 @@ namespace pywrap {
moduleClassMap.erase(result); moduleClassMap.erase(result);
if (PyErr_Occurred()) { if (PyErr_Occurred()) {
PyErr_Print(); PyErr_Print();
errorAbort("Error cleaning module " + moduleName + " and class " + className); errorAbort("Error cleaning module ");
} }
if (moduleClassMap.empty()) { if (moduleClassMap.empty()) {
RemoveInstance(); RemoveInstance();
@@ -86,17 +87,17 @@ namespace pywrap {
RemoveInstance(); RemoveInstance();
exit(1); exit(1);
} }
PyObject* PyWrap::getClass(const std::string& moduleName, const std::string& className) PyObject* PyWrap::getClass(const clfId_t id)
{ {
auto item = moduleClassMap.find({ moduleName, className }); auto item = moduleClassMap.find(id);
if (item == moduleClassMap.end()) { if (item == moduleClassMap.end()) {
errorAbort("Module " + moduleName + " and class " + className + " not found"); errorAbort("Module not found");
} }
return std::get<2>(item->second); return std::get<2>(item->second);
} }
std::string PyWrap::callMethodString(const std::string& moduleName, const std::string& className, const std::string& method) std::string PyWrap::callMethodString(const clfId_t id, const std::string& method)
{ {
PyObject* instance = getClass(moduleName, className); PyObject* instance = getClass(id);
PyObject* result; PyObject* result;
try { try {
if (!(result = PyObject_CallMethod(instance, method.c_str(), NULL))) if (!(result = PyObject_CallMethod(instance, method.c_str(), NULL)))
@@ -109,16 +110,16 @@ namespace pywrap {
Py_XDECREF(result); Py_XDECREF(result);
return value; return value;
} }
std::string PyWrap::version(const std::string& moduleName, const std::string& className) std::string PyWrap::version(const clfId_t id)
{ {
return callMethodString(moduleName, className, "version"); return callMethodString(id, "version");
} }
void PyWrap::setHyperparameters(const std::string& moduleName, const std::string& className, const json& hyperparameters) void PyWrap::setHyperparameters(const clfId_t id, const json& hyperparameters)
{ {
// Set hyperparameters as attributes of the class // Set hyperparameters as attributes of the class
std::cout << "Building dictionary of arguments" << std::endl; std::cout << "Building dictionary of arguments" << std::endl;
PyObject* pValue; PyObject* pValue;
PyObject* instance = getClass(moduleName, className); PyObject* instance = getClass(id);
for (const auto& [key, value] : hyperparameters.items()) { for (const auto& [key, value] : hyperparameters.items()) {
std::stringstream oss; std::stringstream oss;
oss << value.type_name(); oss << value.type_name();
@@ -139,9 +140,9 @@ namespace pywrap {
Py_XDECREF(pValue); Py_XDECREF(pValue);
} }
} }
void PyWrap::fit(const std::string& moduleName, const std::string& className, CPyObject& X, CPyObject& y) void PyWrap::fit(const clfId_t id, CPyObject& X, CPyObject& y)
{ {
PyObject* instance = getClass(moduleName, className); PyObject* instance = getClass(id);
CPyObject result; CPyObject result;
std::string method = "fit"; std::string method = "fit";
try { try {
@@ -154,9 +155,9 @@ namespace pywrap {
// Py_XDECREF(result); // Py_XDECREF(result);
} }
PyObject* PyWrap::predict(const std::string& moduleName, const std::string& className, CPyObject& X) PyObject* PyWrap::predict(const clfId_t id, CPyObject& X)
{ {
PyObject* instance = getClass(moduleName, className); PyObject* instance = getClass(id);
PyObject* result; PyObject* result;
std::string method = "predict"; std::string method = "predict";
try { try {
@@ -169,9 +170,9 @@ namespace pywrap {
Py_INCREF(result); Py_INCREF(result);
return result; // Caller must free this object return result; // Caller must free this object
} }
double PyWrap::score(const std::string& moduleName, const std::string& className, CPyObject& X, CPyObject& y) double PyWrap::score(const clfId_t id, CPyObject& X, CPyObject& y)
{ {
PyObject* instance = getClass(moduleName, className); PyObject* instance = getClass(id);
CPyObject result; CPyObject result;
std::string method = "score"; std::string method = "score";
try { try {

View File

@@ -7,6 +7,7 @@
#include <mutex> #include <mutex>
#include <nlohmann/json.hpp> #include <nlohmann/json.hpp>
#include "PyHelper.hpp" #include "PyHelper.hpp"
#include "TypeId.h"
#pragma once #pragma once
@@ -22,21 +23,21 @@ namespace pywrap {
static PyWrap* GetInstance(); static PyWrap* GetInstance();
void operator=(const PyWrap&) = delete; void operator=(const PyWrap&) = delete;
~PyWrap() = default; ~PyWrap() = default;
std::string callMethodString(const std::string& moduleName, const std::string& className, const std::string& method); std::string callMethodString(const clfId_t id, const std::string& method);
std::string version(const std::string& moduleName, const std::string& className); std::string version(const clfId_t id);
void setHyperparameters(const std::string& moduleName, const std::string& className, const json& hyperparameters); void setHyperparameters(const clfId_t id, const json& hyperparameters);
void fit(const std::string& moduleName, const std::string& className, CPyObject& X, CPyObject& y); void fit(const clfId_t id, CPyObject& X, CPyObject& y);
PyObject* predict(const std::string& moduleName, const std::string& className, CPyObject& X); PyObject* predict(const clfId_t id, CPyObject& X);
double score(const std::string& moduleName, const std::string& className, CPyObject& X, CPyObject& y); double score(const clfId_t id, CPyObject& X, CPyObject& y);
void clean(const std::string& moduleName, const std::string& className); void clean(const clfId_t id);
void importClass(const std::string& moduleName, const std::string& className); void importClass(const clfId_t id, const std::string& moduleName, const std::string& className);
PyObject* getClass(const std::string& moduleName, const std::string& className); PyObject* getClass(const clfId_t id);
private: private:
// Only call RemoveInstance from clean method // Only call RemoveInstance from clean method
static void RemoveInstance(); static void RemoveInstance();
void errorAbort(const std::string& message); void errorAbort(const std::string& message);
// No need to use static map here, since this class is a singleton // No need to use static map here, since this class is a singleton
std::map<std::pair<std::string, std::string>, std::tuple<PyObject*, PyObject*, PyObject*>> moduleClassMap; std::map<clfId_t, std::tuple<PyObject*, PyObject*, PyObject*>> moduleClassMap;
static CPyInstance* pyInstance; static CPyInstance* pyInstance;
static PyWrap* wrapper; static PyWrap* wrapper;
static std::mutex mutex; static std::mutex mutex;

4
src/TypeId.h Normal file
View File

@@ -0,0 +1,4 @@
#ifndef TYPEDEF_H
#define TYPEDEF_H
typedef uint64_t clfId_t;
#endif /* TYPEDEF_H */

View File

@@ -8,6 +8,7 @@
#include "STree.h" #include "STree.h"
#include "SVC.h" #include "SVC.h"
#include "RandomForest.h" #include "RandomForest.h"
#include "XGBoost.h"
using namespace std; using namespace std;
using namespace torch; using namespace torch;
@@ -53,16 +54,16 @@ int main(int argc, char* argv[])
cout << "X: " << X.sizes() << endl; cout << "X: " << X.sizes() << endl;
cout << "y: " << y.sizes() << endl; cout << "y: " << y.sizes() << endl;
auto clf = pywrap::STree(); auto clf = pywrap::STree();
// auto stree = pywrap::STree(); auto stree = pywrap::STree();
auto hyperparameters = json::parse("{\"C\": 0.7, \"max_iter\": 10000, \"kernel\": \"rbf\", \"random_state\": 17}"); auto hyperparameters = json::parse("{\"C\": 0.7, \"max_iter\": 10000, \"kernel\": \"rbf\", \"random_state\": 17}");
// stree.setHyperparameters(hyperparameters); stree.setHyperparameters(hyperparameters);
cout << "STree Version: " << clf.version() << endl; cout << "STree Version: " << clf.version() << endl;
auto svc = pywrap::SVC(); auto svc = pywrap::SVC();
cout << "SVC with hyperparameters" << endl; cout << "SVC with hyperparameters" << endl;
svc.fit(X, y, features, className, states); svc.fit(X, y, features, className, states);
cout << "Graph: " << endl << clf.graph() << endl; cout << "Graph: " << endl << clf.graph() << endl;
clf.fit(X, y, features, className, states); clf.fit(X, y, features, className, states);
// stree.fit(X, y, features, className, states); stree.fit(X, y, features, className, states);
auto prediction = clf.predict(X); auto prediction = clf.predict(X);
cout << "Prediction: " << endl << "{"; cout << "Prediction: " << endl << "{";
for (int i = 0; i < prediction.size(0); ++i) { for (int i = 0; i < prediction.size(0); ++i) {
@@ -71,10 +72,10 @@ int main(int argc, char* argv[])
cout << "}" << endl; cout << "}" << endl;
auto rf = pywrap::RandomForest(); auto rf = pywrap::RandomForest();
rf.fit(X, y, features, className, states); rf.fit(X, y, features, className, states);
auto xg = pywrap::RandomForest(); auto xg = pywrap::XGBoost();
xg.fit(X, y, features, className, states); xg.fit(X, y, features, className, states);
cout << "STree Score ......: " << clf.score(X, y) << endl; cout << "STree Score ......: " << clf.score(X, y) << endl;
// cout << "STree hyper score : " << stree.score(X, y) << endl; cout << "STree hyper score : " << stree.score(X, y) << endl;
cout << "RandomForest Score: " << rf.score(X, y) << endl; cout << "RandomForest Score: " << rf.score(X, y) << endl;
cout << "SVC Score ........: " << svc.score(X, y) << endl; cout << "SVC Score ........: " << svc.score(X, y) << endl;
cout << "XGBoost Score ....: " << xg.score(X, y) << endl; cout << "XGBoost Score ....: " << xg.score(X, y) << endl;