Refactor Classifier classes

This commit is contained in:
2023-11-12 18:35:29 +01:00
parent c7372b7fc7
commit 0059e269dd
12 changed files with 133 additions and 83 deletions

View File

@@ -3,5 +3,6 @@ include_directories(${PyWrap_SOURCE_DIR}/lib/json/include)
include_directories(${Python3_INCLUDE_DIRS})
include_directories(${TORCH_INCLUDE_DIRS})
add_library(PyWrap SHARED PyWrap.cc STree.cc SVC.cc RandomForest.cc PyClassifier.cc)
add_library(PyWrap SHARED PyWrap.cc STree.cc ODTE.cc SVC.cc RandomForest.cc PyClassifier.cc)
#target_link_libraries(PyWrap ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::boost Boost::python Boost::numpy xgboost::xgboost ArffFiles)
target_link_libraries(PyWrap ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::boost Boost::python Boost::numpy ArffFiles)

View File

@@ -1,13 +1,25 @@
#ifndef CLASSIFER_H
#define CLASSIFER_H
#ifndef CLASSIFIER_H
#define CLASSIFIER_H
#include <torch/torch.h>
#include <nlohmann/json.hpp>
#include <string>
#include <map>
#include <vector>
namespace pywrap {
class Classifier {
public:
Classifier() = default;
virtual ~Classifier() = default;
virtual Classifier& fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) = 0;
virtual Classifier& fit(torch::Tensor& X, torch::Tensor& y) = 0;
virtual torch::Tensor predict(torch::Tensor& X) = 0;
virtual double score(torch::Tensor& X, torch::Tensor& y) = 0;
virtual std::string version() = 0;
virtual std::string sklearnVersion() = 0;
virtual void setHyperparameters(const nlohmann::json& hyperparameters) = 0;
protected:
virtual void checkHyperparameters(const std::vector<std::string>& validKeys, const nlohmann::json& hyperparameters) = 0;
};
} /* namespace pywrap */
#endif /* CLASSIFER_H */
#endif /* CLASSIFIER_H */

15
src/ODTE.cc Normal file
View File

@@ -0,0 +1,15 @@
#include "ODTE.h"
namespace pywrap {
std::string ODTE::graph()
{
return callMethodString("graph");
}
void ODTE::setHyperparameters(const nlohmann::json& hyperparameters)
{
// Check if hyperparameters are valid
const std::vector<std::string> validKeys = { "n_jobs", "n_estimators", "random_state" };
checkHyperparameters(validKeys, hyperparameters);
this->hyperparameters = hyperparameters;
}
} /* namespace pywrap */

15
src/ODTE.h Normal file
View File

@@ -0,0 +1,15 @@
#ifndef ODTE_H
#define ODTE_H
#include "nlohmann/json.hpp"
#include "PyClassifier.h"
namespace pywrap {
class ODTE : public PyClassifier {
public:
ODTE() : PyClassifier("odte", "Odte") {};
~ODTE() = default;
std::string graph();
void setHyperparameters(const nlohmann::json& hyperparameters) override;
};
} /* namespace pywrap */
#endif /* ODTE_H */

View File

@@ -31,6 +31,10 @@ namespace pywrap {
{
return pyWrap->version(id);
}
std::string PyClassifier::sklearnVersion()
{
return pyWrap->sklearnVersion();
}
std::string PyClassifier::callMethodString(const std::string& method)
{
return pyWrap->callMethodString(id, method);

View File

@@ -1,5 +1,5 @@
#ifndef PYCLASSIFER_H
#define PYCLASSIFER_H
#ifndef PYCLASSIFIER_H
#define PYCLASSIFIER_H
#include "boost/python/detail/wrap_python.hpp"
#include <boost/python/numpy.hpp>
#include <nlohmann/json.hpp>
@@ -17,15 +17,16 @@ namespace pywrap {
public:
PyClassifier(const std::string& module, const std::string& className);
virtual ~PyClassifier();
PyClassifier& fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states);
PyClassifier& fit(torch::Tensor& X, torch::Tensor& y);
torch::Tensor predict(torch::Tensor& X);
double score(torch::Tensor& X, torch::Tensor& y);
std::string version();
PyClassifier& fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, std::map<std::string, std::vector<int>>& states) override;
PyClassifier& fit(torch::Tensor& X, torch::Tensor& y) override;
torch::Tensor predict(torch::Tensor& X) override;
double score(torch::Tensor& X, torch::Tensor& y) override;
std::string version() override;
std::string sklearnVersion() override;
std::string callMethodString(const std::string& method);
void setHyperparameters(const nlohmann::json& hyperparameters) override;
protected:
void checkHyperparameters(const std::vector<std::string>& validKeys, const nlohmann::json& hyperparameters);
void checkHyperparameters(const std::vector<std::string>& validKeys, const nlohmann::json& hyperparameters) override;
nlohmann::json hyperparameters;
private:
PyWrap* pyWrap;
@@ -35,4 +36,4 @@ namespace pywrap {
bool fitted;
};
} /* namespace pywrap */
#endif /* PYCLASSIFER_H */
#endif /* PYCLASSIFIER_H */

View File

@@ -42,7 +42,6 @@ namespace pywrap {
if (result != moduleClassMap.end()) {
return;
}
std::cout << "1a" << std::endl;
PyObject* module = PyImport_ImportModule(moduleName.c_str());
if (PyErr_Occurred()) {
errorAbort("Couldn't import module " + moduleName);
@@ -107,6 +106,13 @@ namespace pywrap {
Py_XDECREF(result);
return value;
}
std::string PyWrap::sklearnVersion()
{
return "1.0";
// CPyObject data = PyRun_SimpleString("import sklearn;return sklearn.__version__");
// std::string result = PyUnicode_AsUTF8(data);
// return result;
}
std::string PyWrap::version(const clfId_t id)
{
return callMethodString(id, "version");

View File

@@ -24,6 +24,7 @@ namespace pywrap {
void operator=(const PyWrap&) = delete;
~PyWrap() = default;
std::string callMethodString(const clfId_t id, const std::string& method);
std::string sklearnVersion();
std::string version(const clfId_t id);
void setHyperparameters(const clfId_t id, const json& hyperparameters);
void fit(const clfId_t id, CPyObject& X, CPyObject& y);

View File

@@ -3,6 +3,6 @@
namespace pywrap {
std::string RandomForest::version()
{
return callMethodString("1.0");
return sklearnVersion();
}
} /* namespace pywrap */

View File

@@ -3,7 +3,7 @@
namespace pywrap {
std::string SVC::version()
{
return callMethodString("1.0");
return sklearnVersion();
}
void SVC::setHyperparameters(const nlohmann::json& hyperparameters)
{