From 0059e269dd662c8ab7a2730b0b1ae19599946a3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Sun, 12 Nov 2023 18:35:29 +0100 Subject: [PATCH] Refactor Classifier classes --- CMakeLists.txt | 1 + example/example.cc | 128 +++++++++++++++++++++----------------------- src/CMakeLists.txt | 3 +- src/Classifier.h | 18 +++++-- src/ODTE.cc | 15 ++++++ src/ODTE.h | 15 ++++++ src/PyClassifier.cc | 4 ++ src/PyClassifier.h | 19 +++---- src/PyWrap.cc | 8 ++- src/PyWrap.h | 1 + src/RandomForest.cc | 2 +- src/SVC.cc | 2 +- 12 files changed, 133 insertions(+), 83 deletions(-) create mode 100644 src/ODTE.cc create mode 100644 src/ODTE.h diff --git a/CMakeLists.txt b/CMakeLists.txt index c7227b3..049cb64 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -23,6 +23,7 @@ include(AddGitSubmodule) find_package(Python3 3.11...3.11.9 COMPONENTS Interpreter Development REQUIRED) find_package(Torch REQUIRED) find_package(Boost REQUIRED COMPONENTS python3 numpy3) +# find_package(xgboost REQUIRED) # Temporary patch while find_package(Torch) is not fixed file( diff --git a/example/example.cc b/example/example.cc index 4690d95..ecf296d 100644 --- a/example/example.cc +++ b/example/example.cc @@ -5,7 +5,9 @@ #include #include #include +#include "Classifier.h" #include "STree.h" +#include "ODTE.h" #include "SVC.h" #include "RandomForest.h" #include "XGBoost.h" @@ -47,11 +49,22 @@ pair get_train_test_indices(int size) shuffle(indices.begin(), indices.end(), std::default_random_engine(seed)); auto train_indices = torch::zeros({ train_size }, torch::kInt32); auto test_indices = torch::zeros({ test_size }, torch::kInt32); + int ti = 0, ei = 0; + cout << "Train indices ["; + for (auto i = 0; i < train_size; ++i) { + cout << indices.at(i) << ", "; + } + cout << "]" << endl; + cout << "Test indices ["; + for (auto i = train_size; i < size; ++i) { + cout << indices.at(i) << ", "; + } + cout << "]" << endl; for (auto i = 0; i < size; ++i) { if (i < train_size) { - train_indices[i] = indices[i]; + train_indices[ti++] = indices.at(i); } else if (i < size) { - test_indices[i - train_size] = indices[i]; + test_indices[ei++] = indices.at(i); } } return { train_indices, test_indices }; @@ -61,71 +74,52 @@ int main(int argc, char* argv[]) { using json = nlohmann::json; cout << "* Begin." << endl; - { - using namespace torch::indexing; - auto datasetName = "wine"; - bool class_last = true; - auto [X, y] = loadDataset(datasetName, class_last); - // Split train/test - auto [train_indices, test_indices] = get_train_test_indices(X.size(1)); - auto Xtrain = X.index({ "...", train_indices }); - auto ytrain = y.index({ train_indices }); - auto Xtest = X.index({ "...", test_indices }); - auto ytest = y.index({ test_indices }); - cout << "Dataset: " << datasetName << endl; - cout << "X: " << X.sizes() << endl; - cout << "y: " << y.sizes() << endl; - cout << "Xtrain: " << Xtrain.sizes() << endl; - cout << "ytrain: " << ytrain.sizes() << endl; - cout << "Xtest : " << Xtest.sizes() << endl; - cout << "ytest : " << ytest.sizes() << endl; - // - // STree - // - auto clf = pywrap::STree(); - clf.fit(Xtrain, ytest); - double clf_score = clf.score(Xtest, ytest); - // auto stree = pywrap::STree(); - // auto hyperparameters = json::parse("{\"C\": 0.7, \"max_iter\": 10000, \"kernel\": \"rbf\", \"random_state\": 17}"); - // stree.setHyperparameters(hyperparameters); - // cout << "STree Version: " << clf.version() << endl; - // auto prediction = clf.predict(X); - // cout << "Prediction: " << endl << "{"; - // for (int i = 0; i < prediction.size(0); ++i) { - // cout << prediction[i].item() << ", "; - // } - // cout << "}" << endl; - // - // SVC - // - // auto svc = pywrap::SVC(); - // cout << "SVC with hyperparameters" << endl; - // svc.fit(Xtrain, ytrain); - // - // Random Forest - // - // cout << "Building Random Forest" << endl; - // auto rf = pywrap::RandomForest(); - // rf.fit(Xtrain, ytrain); - // - // XGBoost - // - // cout << "Building XGBoost" << endl; - // auto xg = pywrap::XGBoost(); - // cout << "Fitting XGBoost" << endl; - // xg.fit(Xtrain, ytrain); - // double xg_score = xg.score(Xtest, ytest); - // - // Scoring - // - cout << "Scoring dataset: " << datasetName << endl; - cout << "Scores:" << endl; - cout << "STree Score ......: " << clf_score << endl; - // cout << "STree train/test .: " << clf.fit(Xtrain, ytrain).score(Xtest, ytest) << endl; - // cout << "STree hyper score : " << stree.fit(Xtrain, ytrain).score(Xtest, ytest) << endl; - // cout << "RandomForest Score: " << rf.score(Xtest, ytest) << endl; - // cout << "SVC Score ........: " << svc.score(Xtest, ytest) << endl; - // cout << "XGBoost Score ....: " << xg_score << endl; + using namespace torch::indexing; + map classifiers = { + {"STree", new pywrap::STree()}, {"SVC", new pywrap::SVC()}, + {"RandomForest", new pywrap::RandomForest()},// {"XGBoost", new XGBoost()}, + {"ODTE", new pywrap::ODTE()} + }; + // + // Load dataset + // + auto datasetName = "wine"; + bool class_last = false; + auto [X, y] = loadDataset(datasetName, class_last); + // + // Split train/test + // + auto [train_indices, test_indices] = get_train_test_indices(X.size(1)); + auto Xtrain = X.index({ "...", train_indices }); + auto ytrain = y.index({ train_indices }); + auto Xtest = X.index({ "...", test_indices }); + auto ytest = y.index({ test_indices }); + cout << "Dataset: " << datasetName << endl; + cout << "X: " << X.sizes() << endl; + cout << "y: " << y.sizes() << endl; + cout << "Xtrain: " << Xtrain.sizes() << endl; + cout << "ytrain: " << ytrain.sizes() << endl; + cout << "Xtest : " << Xtest.sizes() << endl; + cout << "ytest : " << ytest.sizes() << endl; + // + // Train classifiers + // + for (auto& [name, clf] : classifiers) { + cout << "Training " << name << endl; + clf->fit(Xtrain, ytrain); + } + // + // Show scores + // + for (auto& [name, clf] : classifiers) { + cout << "Score " << setw(10) << name << "(Ver. " << clf->version() << "): " + << clf->score(Xtest, ytest) << endl; + } + // + // Free classifiers + // + for (auto& [name, clf] : classifiers) { + delete clf; } cout << "* End." << endl; } \ No newline at end of file diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 6970ec1..9c99a7e 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -3,5 +3,6 @@ include_directories(${PyWrap_SOURCE_DIR}/lib/json/include) include_directories(${Python3_INCLUDE_DIRS}) include_directories(${TORCH_INCLUDE_DIRS}) -add_library(PyWrap SHARED PyWrap.cc STree.cc SVC.cc RandomForest.cc PyClassifier.cc) +add_library(PyWrap SHARED PyWrap.cc STree.cc ODTE.cc SVC.cc RandomForest.cc PyClassifier.cc) +#target_link_libraries(PyWrap ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::boost Boost::python Boost::numpy xgboost::xgboost ArffFiles) target_link_libraries(PyWrap ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::boost Boost::python Boost::numpy ArffFiles) \ No newline at end of file diff --git a/src/Classifier.h b/src/Classifier.h index cefff5e..8160aa9 100644 --- a/src/Classifier.h +++ b/src/Classifier.h @@ -1,13 +1,25 @@ -#ifndef CLASSIFER_H -#define CLASSIFER_H +#ifndef CLASSIFIER_H +#define CLASSIFIER_H +#include #include +#include +#include +#include namespace pywrap { class Classifier { public: Classifier() = default; virtual ~Classifier() = default; + virtual Classifier& fit(torch::Tensor& X, torch::Tensor& y, const std::vector& features, const std::string& className, std::map>& states) = 0; + virtual Classifier& fit(torch::Tensor& X, torch::Tensor& y) = 0; + virtual torch::Tensor predict(torch::Tensor& X) = 0; + virtual double score(torch::Tensor& X, torch::Tensor& y) = 0; + virtual std::string version() = 0; + virtual std::string sklearnVersion() = 0; virtual void setHyperparameters(const nlohmann::json& hyperparameters) = 0; + protected: + virtual void checkHyperparameters(const std::vector& validKeys, const nlohmann::json& hyperparameters) = 0; }; } /* namespace pywrap */ -#endif /* CLASSIFER_H */ \ No newline at end of file +#endif /* CLASSIFIER_H */ \ No newline at end of file diff --git a/src/ODTE.cc b/src/ODTE.cc new file mode 100644 index 0000000..f168f43 --- /dev/null +++ b/src/ODTE.cc @@ -0,0 +1,15 @@ +#include "ODTE.h" + +namespace pywrap { + std::string ODTE::graph() + { + return callMethodString("graph"); + } + void ODTE::setHyperparameters(const nlohmann::json& hyperparameters) + { + // Check if hyperparameters are valid + const std::vector validKeys = { "n_jobs", "n_estimators", "random_state" }; + checkHyperparameters(validKeys, hyperparameters); + this->hyperparameters = hyperparameters; + } +} /* namespace pywrap */ \ No newline at end of file diff --git a/src/ODTE.h b/src/ODTE.h new file mode 100644 index 0000000..1c90951 --- /dev/null +++ b/src/ODTE.h @@ -0,0 +1,15 @@ +#ifndef ODTE_H +#define ODTE_H +#include "nlohmann/json.hpp" +#include "PyClassifier.h" + +namespace pywrap { + class ODTE : public PyClassifier { + public: + ODTE() : PyClassifier("odte", "Odte") {}; + ~ODTE() = default; + std::string graph(); + void setHyperparameters(const nlohmann::json& hyperparameters) override; + }; +} /* namespace pywrap */ +#endif /* ODTE_H */ \ No newline at end of file diff --git a/src/PyClassifier.cc b/src/PyClassifier.cc index 312202e..9da85ab 100644 --- a/src/PyClassifier.cc +++ b/src/PyClassifier.cc @@ -31,6 +31,10 @@ namespace pywrap { { return pyWrap->version(id); } + std::string PyClassifier::sklearnVersion() + { + return pyWrap->sklearnVersion(); + } std::string PyClassifier::callMethodString(const std::string& method) { return pyWrap->callMethodString(id, method); diff --git a/src/PyClassifier.h b/src/PyClassifier.h index 39dff49..b243f68 100644 --- a/src/PyClassifier.h +++ b/src/PyClassifier.h @@ -1,5 +1,5 @@ -#ifndef PYCLASSIFER_H -#define PYCLASSIFER_H +#ifndef PYCLASSIFIER_H +#define PYCLASSIFIER_H #include "boost/python/detail/wrap_python.hpp" #include #include @@ -17,15 +17,16 @@ namespace pywrap { public: PyClassifier(const std::string& module, const std::string& className); virtual ~PyClassifier(); - PyClassifier& fit(torch::Tensor& X, torch::Tensor& y, const std::vector& features, const std::string& className, std::map>& states); - PyClassifier& fit(torch::Tensor& X, torch::Tensor& y); - torch::Tensor predict(torch::Tensor& X); - double score(torch::Tensor& X, torch::Tensor& y); - std::string version(); + PyClassifier& fit(torch::Tensor& X, torch::Tensor& y, const std::vector& features, const std::string& className, std::map>& states) override; + PyClassifier& fit(torch::Tensor& X, torch::Tensor& y) override; + torch::Tensor predict(torch::Tensor& X) override; + double score(torch::Tensor& X, torch::Tensor& y) override; + std::string version() override; + std::string sklearnVersion() override; std::string callMethodString(const std::string& method); void setHyperparameters(const nlohmann::json& hyperparameters) override; protected: - void checkHyperparameters(const std::vector& validKeys, const nlohmann::json& hyperparameters); + void checkHyperparameters(const std::vector& validKeys, const nlohmann::json& hyperparameters) override; nlohmann::json hyperparameters; private: PyWrap* pyWrap; @@ -35,4 +36,4 @@ namespace pywrap { bool fitted; }; } /* namespace pywrap */ -#endif /* PYCLASSIFER_H */ \ No newline at end of file +#endif /* PYCLASSIFIER_H */ \ No newline at end of file diff --git a/src/PyWrap.cc b/src/PyWrap.cc index decd296..2912749 100644 --- a/src/PyWrap.cc +++ b/src/PyWrap.cc @@ -42,7 +42,6 @@ namespace pywrap { if (result != moduleClassMap.end()) { return; } - std::cout << "1a" << std::endl; PyObject* module = PyImport_ImportModule(moduleName.c_str()); if (PyErr_Occurred()) { errorAbort("Couldn't import module " + moduleName); @@ -107,6 +106,13 @@ namespace pywrap { Py_XDECREF(result); return value; } + std::string PyWrap::sklearnVersion() + { + return "1.0"; + // CPyObject data = PyRun_SimpleString("import sklearn;return sklearn.__version__"); + // std::string result = PyUnicode_AsUTF8(data); + // return result; + } std::string PyWrap::version(const clfId_t id) { return callMethodString(id, "version"); diff --git a/src/PyWrap.h b/src/PyWrap.h index 13002d6..7f83c99 100644 --- a/src/PyWrap.h +++ b/src/PyWrap.h @@ -24,6 +24,7 @@ namespace pywrap { void operator=(const PyWrap&) = delete; ~PyWrap() = default; std::string callMethodString(const clfId_t id, const std::string& method); + std::string sklearnVersion(); std::string version(const clfId_t id); void setHyperparameters(const clfId_t id, const json& hyperparameters); void fit(const clfId_t id, CPyObject& X, CPyObject& y); diff --git a/src/RandomForest.cc b/src/RandomForest.cc index 196b1a7..dd0be1f 100644 --- a/src/RandomForest.cc +++ b/src/RandomForest.cc @@ -3,6 +3,6 @@ namespace pywrap { std::string RandomForest::version() { - return callMethodString("1.0"); + return sklearnVersion(); } } /* namespace pywrap */ \ No newline at end of file diff --git a/src/SVC.cc b/src/SVC.cc index 225bf0a..3245903 100644 --- a/src/SVC.cc +++ b/src/SVC.cc @@ -3,7 +3,7 @@ namespace pywrap { std::string SVC::version() { - return callMethodString("1.0"); + return sklearnVersion(); } void SVC::setHyperparameters(const nlohmann::json& hyperparameters) {