Add XGBoost & RandomForest

This commit is contained in:
2023-11-07 17:41:37 +01:00
parent e0481dfa44
commit 1f46fc6c24
9 changed files with 60 additions and 23 deletions

View File

@@ -2,7 +2,7 @@ include_directories(${PyWrap_SOURCE_DIR}/lib/Files)
include_directories(${Python3_INCLUDE_DIRS})
include_directories(${TORCH_INCLUDE_DIRS})
add_executable(main main.cc STree.cc SVC.cc PyClassifier.cc PyWrap.cc)
add_executable(main main.cc STree.cc SVC.cc RandomForest.cc PyClassifier.cc PyWrap.cc)
add_executable(example example.cpp PyWrap.cc)
target_link_libraries(main ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::boost Boost::python Boost::numpy ArffFiles)

View File

@@ -23,8 +23,6 @@ namespace pywrap {
Py_Finalize();
}
};
class CPyObject {
private:
PyObject* p;
@@ -36,23 +34,18 @@ namespace pywrap {
CPyObject(PyObject* _p) : p(_p)
{
}
~CPyObject()
{
Release();
}
PyObject* getObject()
{
return p;
}
PyObject* setObject(PyObject* _p)
{
return (p = _p);
}
PyObject* AddRef()
{
if (p) {
@@ -60,7 +53,6 @@ namespace pywrap {
}
return p;
}
void Release()
{
if (p) {
@@ -69,33 +61,27 @@ namespace pywrap {
p = NULL;
}
PyObject* operator ->()
{
return p;
}
bool is()
{
return p ? true : false;
}
operator PyObject* ()
{
return p;
}
PyObject* operator = (PyObject* pp)
{
p = pp;
return p;
}
operator bool()
{
return p ? true : false;
}
};
} /* namespace pywrap */
#endif

View File

@@ -143,6 +143,7 @@ namespace pywrap {
RemoveInstance();
exit(1);
}
Py_INCREF(result);
return result; // Caller must free this object
}
double PyWrap::score(const std::string& moduleName, const std::string& className, CPyObject& X, CPyObject& y)

8
src/RandomForest.cc Normal file
View File

@@ -0,0 +1,8 @@
#include "RandomForest.h"
namespace pywrap {
std::string RandomForest::version()
{
return callMethodString("1.0");
}
} /* namespace pywrap */

13
src/RandomForest.h Normal file
View File

@@ -0,0 +1,13 @@
#ifndef RANDOMFOREST_H
#define RANDOMFOREST_H
#include "PyClassifier.h"
namespace pywrap {
class RandomForest : public PyClassifier {
public:
RandomForest() : PyClassifier("sklearn.ensemble", "RandomForestClassifier") {};
~RandomForest() = default;
std::string version();
};
} /* namespace pywrap */
#endif /* RANDOMFOREST_H */

View File

@@ -3,6 +3,6 @@
namespace pywrap {
std::string SVC::version()
{
return callMethodString("_repr_html_");
return callMethodString("1.0");
}
} /* namespace pywrap */

8
src/XGBoost.cc Normal file
View File

@@ -0,0 +1,8 @@
#include "XGBoost.h"
namespace pywrap {
std::string XGBoost::version()
{
return callMethodString("1.0");
}
} /* namespace pywrap */

13
src/XGBoost.h Normal file
View File

@@ -0,0 +1,13 @@
#ifndef XGBOOST_H
#define XGBOOST_H
#include "PyClassifier.h"
namespace pywrap {
class XGBoost : public PyClassifier {
public:
XGBoost() : PyClassifier("xgboost", "XGBClassifier") {};
~XGBoost() = default;
std::string version();
};
} /* namespace pywrap */
#endif /* XGBOOST_H */

View File

@@ -7,6 +7,7 @@
#include <tuple>
#include "STree.h"
#include "SVC.h"
#include "RandomForest.h"
using namespace std;
using namespace torch;
@@ -44,25 +45,32 @@ int main(int argc, char* argv[])
{
cout << "* Begin." << endl;
{
auto [X, y, features, className, states] = loadDataset("wine", false);
auto datasetName = "iris";
bool class_last = true;
auto [X, y, features, className, states] = loadDataset(datasetName, class_last);
cout << "Dataset: " << datasetName << endl;
cout << "X: " << X.sizes() << endl;
cout << "y: " << y.sizes() << endl;
auto clf = pywrap::STree();
cout << "STree Version: " << clf.version() << endl;
if (true) {
auto svc = pywrap::SVC();
svc.fit(X, y, features, className, states);
cout << "SVC Score: " << svc.score(X, y) << endl;
}
auto svc = pywrap::SVC();
svc.fit(X, y, features, className, states);
cout << "Graph: " << endl << clf.graph() << endl;
clf.fit(X, y, features, className, states);
cout << "STree Score: " << clf.score(X, y) << endl;
auto prediction = clf.predict(X);
cout << "Prediction: " << endl << "{";
for (int i = 0; i < prediction.size(0); ++i) {
cout << prediction[i].item<int>() << ", ";
}
cout << "}" << endl;
auto rf = pywrap::RandomForest();
rf.fit(X, y, features, className, states);
auto xg = pywrap::RandomForest();
xg.fit(X, y, features, className, states);
cout << "STree Score ......: " << clf.score(X, y) << endl;
cout << "RandomForest Score: " << rf.score(X, y) << endl;
cout << "SVC Score ........: " << svc.score(X, y) << endl;
cout << "XGBoost Score ....: " << xg.score(X, y) << endl;
}
cout << "* End." << endl;
}