Add XGBoost & RandomForest

This commit is contained in:
2023-11-07 17:41:37 +01:00
parent e0481dfa44
commit 1f46fc6c24
9 changed files with 60 additions and 23 deletions

View File

@@ -2,7 +2,7 @@ include_directories(${PyWrap_SOURCE_DIR}/lib/Files)
include_directories(${Python3_INCLUDE_DIRS}) include_directories(${Python3_INCLUDE_DIRS})
include_directories(${TORCH_INCLUDE_DIRS}) include_directories(${TORCH_INCLUDE_DIRS})
add_executable(main main.cc STree.cc SVC.cc PyClassifier.cc PyWrap.cc) add_executable(main main.cc STree.cc SVC.cc RandomForest.cc PyClassifier.cc PyWrap.cc)
add_executable(example example.cpp PyWrap.cc) add_executable(example example.cpp PyWrap.cc)
target_link_libraries(main ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::boost Boost::python Boost::numpy ArffFiles) target_link_libraries(main ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::boost Boost::python Boost::numpy ArffFiles)

View File

@@ -23,8 +23,6 @@ namespace pywrap {
Py_Finalize(); Py_Finalize();
} }
}; };
class CPyObject { class CPyObject {
private: private:
PyObject* p; PyObject* p;
@@ -36,23 +34,18 @@ namespace pywrap {
CPyObject(PyObject* _p) : p(_p) CPyObject(PyObject* _p) : p(_p)
{ {
} }
~CPyObject() ~CPyObject()
{ {
Release(); Release();
} }
PyObject* getObject() PyObject* getObject()
{ {
return p; return p;
} }
PyObject* setObject(PyObject* _p) PyObject* setObject(PyObject* _p)
{ {
return (p = _p); return (p = _p);
} }
PyObject* AddRef() PyObject* AddRef()
{ {
if (p) { if (p) {
@@ -60,7 +53,6 @@ namespace pywrap {
} }
return p; return p;
} }
void Release() void Release()
{ {
if (p) { if (p) {
@@ -69,33 +61,27 @@ namespace pywrap {
p = NULL; p = NULL;
} }
PyObject* operator ->() PyObject* operator ->()
{ {
return p; return p;
} }
bool is() bool is()
{ {
return p ? true : false; return p ? true : false;
} }
operator PyObject* () operator PyObject* ()
{ {
return p; return p;
} }
PyObject* operator = (PyObject* pp) PyObject* operator = (PyObject* pp)
{ {
p = pp; p = pp;
return p; return p;
} }
operator bool() operator bool()
{ {
return p ? true : false; return p ? true : false;
} }
}; };
} /* namespace pywrap */ } /* namespace pywrap */
#endif #endif

View File

@@ -143,6 +143,7 @@ namespace pywrap {
RemoveInstance(); RemoveInstance();
exit(1); exit(1);
} }
Py_INCREF(result);
return result; // Caller must free this object return result; // Caller must free this object
} }
double PyWrap::score(const std::string& moduleName, const std::string& className, CPyObject& X, CPyObject& y) double PyWrap::score(const std::string& moduleName, const std::string& className, CPyObject& X, CPyObject& y)

8
src/RandomForest.cc Normal file
View File

@@ -0,0 +1,8 @@
#include "RandomForest.h"
namespace pywrap {
std::string RandomForest::version()
{
return callMethodString("1.0");
}
} /* namespace pywrap */

13
src/RandomForest.h Normal file
View File

@@ -0,0 +1,13 @@
#ifndef RANDOMFOREST_H
#define RANDOMFOREST_H
#include "PyClassifier.h"
namespace pywrap {
class RandomForest : public PyClassifier {
public:
RandomForest() : PyClassifier("sklearn.ensemble", "RandomForestClassifier") {};
~RandomForest() = default;
std::string version();
};
} /* namespace pywrap */
#endif /* RANDOMFOREST_H */

View File

@@ -3,6 +3,6 @@
namespace pywrap { namespace pywrap {
std::string SVC::version() std::string SVC::version()
{ {
return callMethodString("_repr_html_"); return callMethodString("1.0");
} }
} /* namespace pywrap */ } /* namespace pywrap */

8
src/XGBoost.cc Normal file
View File

@@ -0,0 +1,8 @@
#include "XGBoost.h"
namespace pywrap {
std::string XGBoost::version()
{
return callMethodString("1.0");
}
} /* namespace pywrap */

13
src/XGBoost.h Normal file
View File

@@ -0,0 +1,13 @@
#ifndef XGBOOST_H
#define XGBOOST_H
#include "PyClassifier.h"
namespace pywrap {
class XGBoost : public PyClassifier {
public:
XGBoost() : PyClassifier("xgboost", "XGBClassifier") {};
~XGBoost() = default;
std::string version();
};
} /* namespace pywrap */
#endif /* XGBOOST_H */

View File

@@ -7,6 +7,7 @@
#include <tuple> #include <tuple>
#include "STree.h" #include "STree.h"
#include "SVC.h" #include "SVC.h"
#include "RandomForest.h"
using namespace std; using namespace std;
using namespace torch; using namespace torch;
@@ -44,25 +45,32 @@ int main(int argc, char* argv[])
{ {
cout << "* Begin." << endl; cout << "* Begin." << endl;
{ {
auto [X, y, features, className, states] = loadDataset("wine", false); auto datasetName = "iris";
bool class_last = true;
auto [X, y, features, className, states] = loadDataset(datasetName, class_last);
cout << "Dataset: " << datasetName << endl;
cout << "X: " << X.sizes() << endl; cout << "X: " << X.sizes() << endl;
cout << "y: " << y.sizes() << endl; cout << "y: " << y.sizes() << endl;
auto clf = pywrap::STree(); auto clf = pywrap::STree();
cout << "STree Version: " << clf.version() << endl; cout << "STree Version: " << clf.version() << endl;
if (true) { auto svc = pywrap::SVC();
auto svc = pywrap::SVC(); svc.fit(X, y, features, className, states);
svc.fit(X, y, features, className, states);
cout << "SVC Score: " << svc.score(X, y) << endl;
}
cout << "Graph: " << endl << clf.graph() << endl; cout << "Graph: " << endl << clf.graph() << endl;
clf.fit(X, y, features, className, states); clf.fit(X, y, features, className, states);
cout << "STree Score: " << clf.score(X, y) << endl;
auto prediction = clf.predict(X); auto prediction = clf.predict(X);
cout << "Prediction: " << endl << "{"; cout << "Prediction: " << endl << "{";
for (int i = 0; i < prediction.size(0); ++i) { for (int i = 0; i < prediction.size(0); ++i) {
cout << prediction[i].item<int>() << ", "; cout << prediction[i].item<int>() << ", ";
} }
cout << "}" << endl; cout << "}" << endl;
auto rf = pywrap::RandomForest();
rf.fit(X, y, features, className, states);
auto xg = pywrap::RandomForest();
xg.fit(X, y, features, className, states);
cout << "STree Score ......: " << clf.score(X, y) << endl;
cout << "RandomForest Score: " << rf.score(X, y) << endl;
cout << "SVC Score ........: " << svc.score(X, y) << endl;
cout << "XGBoost Score ....: " << xg.score(X, y) << endl;
} }
cout << "* End." << endl; cout << "* End." << endl;
} }