Add XGBoost & RandomForest
This commit is contained in:
@@ -2,7 +2,7 @@ include_directories(${PyWrap_SOURCE_DIR}/lib/Files)
|
|||||||
include_directories(${Python3_INCLUDE_DIRS})
|
include_directories(${Python3_INCLUDE_DIRS})
|
||||||
include_directories(${TORCH_INCLUDE_DIRS})
|
include_directories(${TORCH_INCLUDE_DIRS})
|
||||||
|
|
||||||
add_executable(main main.cc STree.cc SVC.cc PyClassifier.cc PyWrap.cc)
|
add_executable(main main.cc STree.cc SVC.cc RandomForest.cc PyClassifier.cc PyWrap.cc)
|
||||||
add_executable(example example.cpp PyWrap.cc)
|
add_executable(example example.cpp PyWrap.cc)
|
||||||
|
|
||||||
target_link_libraries(main ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::boost Boost::python Boost::numpy ArffFiles)
|
target_link_libraries(main ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::boost Boost::python Boost::numpy ArffFiles)
|
||||||
|
@@ -23,8 +23,6 @@ namespace pywrap {
|
|||||||
Py_Finalize();
|
Py_Finalize();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
class CPyObject {
|
class CPyObject {
|
||||||
private:
|
private:
|
||||||
PyObject* p;
|
PyObject* p;
|
||||||
@@ -36,23 +34,18 @@ namespace pywrap {
|
|||||||
CPyObject(PyObject* _p) : p(_p)
|
CPyObject(PyObject* _p) : p(_p)
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
~CPyObject()
|
~CPyObject()
|
||||||
{
|
{
|
||||||
Release();
|
Release();
|
||||||
}
|
}
|
||||||
|
|
||||||
PyObject* getObject()
|
PyObject* getObject()
|
||||||
{
|
{
|
||||||
return p;
|
return p;
|
||||||
}
|
}
|
||||||
|
|
||||||
PyObject* setObject(PyObject* _p)
|
PyObject* setObject(PyObject* _p)
|
||||||
{
|
{
|
||||||
return (p = _p);
|
return (p = _p);
|
||||||
}
|
}
|
||||||
|
|
||||||
PyObject* AddRef()
|
PyObject* AddRef()
|
||||||
{
|
{
|
||||||
if (p) {
|
if (p) {
|
||||||
@@ -60,7 +53,6 @@ namespace pywrap {
|
|||||||
}
|
}
|
||||||
return p;
|
return p;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Release()
|
void Release()
|
||||||
{
|
{
|
||||||
if (p) {
|
if (p) {
|
||||||
@@ -69,33 +61,27 @@ namespace pywrap {
|
|||||||
|
|
||||||
p = NULL;
|
p = NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
PyObject* operator ->()
|
PyObject* operator ->()
|
||||||
{
|
{
|
||||||
return p;
|
return p;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is()
|
bool is()
|
||||||
{
|
{
|
||||||
return p ? true : false;
|
return p ? true : false;
|
||||||
}
|
}
|
||||||
|
|
||||||
operator PyObject* ()
|
operator PyObject* ()
|
||||||
{
|
{
|
||||||
return p;
|
return p;
|
||||||
}
|
}
|
||||||
|
|
||||||
PyObject* operator = (PyObject* pp)
|
PyObject* operator = (PyObject* pp)
|
||||||
{
|
{
|
||||||
p = pp;
|
p = pp;
|
||||||
return p;
|
return p;
|
||||||
}
|
}
|
||||||
|
|
||||||
operator bool()
|
operator bool()
|
||||||
{
|
{
|
||||||
return p ? true : false;
|
return p ? true : false;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} /* namespace pywrap */
|
} /* namespace pywrap */
|
||||||
#endif
|
#endif
|
@@ -143,6 +143,7 @@ namespace pywrap {
|
|||||||
RemoveInstance();
|
RemoveInstance();
|
||||||
exit(1);
|
exit(1);
|
||||||
}
|
}
|
||||||
|
Py_INCREF(result);
|
||||||
return result; // Caller must free this object
|
return result; // Caller must free this object
|
||||||
}
|
}
|
||||||
double PyWrap::score(const std::string& moduleName, const std::string& className, CPyObject& X, CPyObject& y)
|
double PyWrap::score(const std::string& moduleName, const std::string& className, CPyObject& X, CPyObject& y)
|
||||||
|
8
src/RandomForest.cc
Normal file
8
src/RandomForest.cc
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
#include "RandomForest.h"
|
||||||
|
|
||||||
|
namespace pywrap {
|
||||||
|
std::string RandomForest::version()
|
||||||
|
{
|
||||||
|
return callMethodString("1.0");
|
||||||
|
}
|
||||||
|
} /* namespace pywrap */
|
13
src/RandomForest.h
Normal file
13
src/RandomForest.h
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
#ifndef RANDOMFOREST_H
|
||||||
|
#define RANDOMFOREST_H
|
||||||
|
#include "PyClassifier.h"
|
||||||
|
|
||||||
|
namespace pywrap {
|
||||||
|
class RandomForest : public PyClassifier {
|
||||||
|
public:
|
||||||
|
RandomForest() : PyClassifier("sklearn.ensemble", "RandomForestClassifier") {};
|
||||||
|
~RandomForest() = default;
|
||||||
|
std::string version();
|
||||||
|
};
|
||||||
|
} /* namespace pywrap */
|
||||||
|
#endif /* RANDOMFOREST_H */
|
@@ -3,6 +3,6 @@
|
|||||||
namespace pywrap {
|
namespace pywrap {
|
||||||
std::string SVC::version()
|
std::string SVC::version()
|
||||||
{
|
{
|
||||||
return callMethodString("_repr_html_");
|
return callMethodString("1.0");
|
||||||
}
|
}
|
||||||
} /* namespace pywrap */
|
} /* namespace pywrap */
|
8
src/XGBoost.cc
Normal file
8
src/XGBoost.cc
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
#include "XGBoost.h"
|
||||||
|
|
||||||
|
namespace pywrap {
|
||||||
|
std::string XGBoost::version()
|
||||||
|
{
|
||||||
|
return callMethodString("1.0");
|
||||||
|
}
|
||||||
|
} /* namespace pywrap */
|
13
src/XGBoost.h
Normal file
13
src/XGBoost.h
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
#ifndef XGBOOST_H
|
||||||
|
#define XGBOOST_H
|
||||||
|
#include "PyClassifier.h"
|
||||||
|
|
||||||
|
namespace pywrap {
|
||||||
|
class XGBoost : public PyClassifier {
|
||||||
|
public:
|
||||||
|
XGBoost() : PyClassifier("xgboost", "XGBClassifier") {};
|
||||||
|
~XGBoost() = default;
|
||||||
|
std::string version();
|
||||||
|
};
|
||||||
|
} /* namespace pywrap */
|
||||||
|
#endif /* XGBOOST_H */
|
22
src/main.cc
22
src/main.cc
@@ -7,6 +7,7 @@
|
|||||||
#include <tuple>
|
#include <tuple>
|
||||||
#include "STree.h"
|
#include "STree.h"
|
||||||
#include "SVC.h"
|
#include "SVC.h"
|
||||||
|
#include "RandomForest.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace torch;
|
using namespace torch;
|
||||||
@@ -44,25 +45,32 @@ int main(int argc, char* argv[])
|
|||||||
{
|
{
|
||||||
cout << "* Begin." << endl;
|
cout << "* Begin." << endl;
|
||||||
{
|
{
|
||||||
auto [X, y, features, className, states] = loadDataset("wine", false);
|
auto datasetName = "iris";
|
||||||
|
bool class_last = true;
|
||||||
|
auto [X, y, features, className, states] = loadDataset(datasetName, class_last);
|
||||||
|
cout << "Dataset: " << datasetName << endl;
|
||||||
cout << "X: " << X.sizes() << endl;
|
cout << "X: " << X.sizes() << endl;
|
||||||
cout << "y: " << y.sizes() << endl;
|
cout << "y: " << y.sizes() << endl;
|
||||||
auto clf = pywrap::STree();
|
auto clf = pywrap::STree();
|
||||||
cout << "STree Version: " << clf.version() << endl;
|
cout << "STree Version: " << clf.version() << endl;
|
||||||
if (true) {
|
auto svc = pywrap::SVC();
|
||||||
auto svc = pywrap::SVC();
|
svc.fit(X, y, features, className, states);
|
||||||
svc.fit(X, y, features, className, states);
|
|
||||||
cout << "SVC Score: " << svc.score(X, y) << endl;
|
|
||||||
}
|
|
||||||
cout << "Graph: " << endl << clf.graph() << endl;
|
cout << "Graph: " << endl << clf.graph() << endl;
|
||||||
clf.fit(X, y, features, className, states);
|
clf.fit(X, y, features, className, states);
|
||||||
cout << "STree Score: " << clf.score(X, y) << endl;
|
|
||||||
auto prediction = clf.predict(X);
|
auto prediction = clf.predict(X);
|
||||||
cout << "Prediction: " << endl << "{";
|
cout << "Prediction: " << endl << "{";
|
||||||
for (int i = 0; i < prediction.size(0); ++i) {
|
for (int i = 0; i < prediction.size(0); ++i) {
|
||||||
cout << prediction[i].item<int>() << ", ";
|
cout << prediction[i].item<int>() << ", ";
|
||||||
}
|
}
|
||||||
cout << "}" << endl;
|
cout << "}" << endl;
|
||||||
|
auto rf = pywrap::RandomForest();
|
||||||
|
rf.fit(X, y, features, className, states);
|
||||||
|
auto xg = pywrap::RandomForest();
|
||||||
|
xg.fit(X, y, features, className, states);
|
||||||
|
cout << "STree Score ......: " << clf.score(X, y) << endl;
|
||||||
|
cout << "RandomForest Score: " << rf.score(X, y) << endl;
|
||||||
|
cout << "SVC Score ........: " << svc.score(X, y) << endl;
|
||||||
|
cout << "XGBoost Score ....: " << xg.score(X, y) << endl;
|
||||||
}
|
}
|
||||||
cout << "* End." << endl;
|
cout << "* End." << endl;
|
||||||
}
|
}
|
Reference in New Issue
Block a user