Add XGBoost & RandomForest
This commit is contained in:
@@ -2,7 +2,7 @@ include_directories(${PyWrap_SOURCE_DIR}/lib/Files)
|
||||
include_directories(${Python3_INCLUDE_DIRS})
|
||||
include_directories(${TORCH_INCLUDE_DIRS})
|
||||
|
||||
add_executable(main main.cc STree.cc SVC.cc PyClassifier.cc PyWrap.cc)
|
||||
add_executable(main main.cc STree.cc SVC.cc RandomForest.cc PyClassifier.cc PyWrap.cc)
|
||||
add_executable(example example.cpp PyWrap.cc)
|
||||
|
||||
target_link_libraries(main ${Python3_LIBRARIES} "${TORCH_LIBRARIES}" ${LIBTORCH_PYTHON} Boost::boost Boost::python Boost::numpy ArffFiles)
|
||||
|
@@ -23,8 +23,6 @@ namespace pywrap {
|
||||
Py_Finalize();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
class CPyObject {
|
||||
private:
|
||||
PyObject* p;
|
||||
@@ -36,23 +34,18 @@ namespace pywrap {
|
||||
CPyObject(PyObject* _p) : p(_p)
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
~CPyObject()
|
||||
{
|
||||
Release();
|
||||
}
|
||||
|
||||
PyObject* getObject()
|
||||
{
|
||||
return p;
|
||||
}
|
||||
|
||||
PyObject* setObject(PyObject* _p)
|
||||
{
|
||||
return (p = _p);
|
||||
}
|
||||
|
||||
PyObject* AddRef()
|
||||
{
|
||||
if (p) {
|
||||
@@ -60,7 +53,6 @@ namespace pywrap {
|
||||
}
|
||||
return p;
|
||||
}
|
||||
|
||||
void Release()
|
||||
{
|
||||
if (p) {
|
||||
@@ -69,33 +61,27 @@ namespace pywrap {
|
||||
|
||||
p = NULL;
|
||||
}
|
||||
|
||||
PyObject* operator ->()
|
||||
{
|
||||
return p;
|
||||
}
|
||||
|
||||
bool is()
|
||||
{
|
||||
return p ? true : false;
|
||||
}
|
||||
|
||||
operator PyObject* ()
|
||||
{
|
||||
return p;
|
||||
}
|
||||
|
||||
PyObject* operator = (PyObject* pp)
|
||||
{
|
||||
p = pp;
|
||||
return p;
|
||||
}
|
||||
|
||||
operator bool()
|
||||
{
|
||||
return p ? true : false;
|
||||
}
|
||||
};
|
||||
|
||||
} /* namespace pywrap */
|
||||
#endif
|
@@ -143,6 +143,7 @@ namespace pywrap {
|
||||
RemoveInstance();
|
||||
exit(1);
|
||||
}
|
||||
Py_INCREF(result);
|
||||
return result; // Caller must free this object
|
||||
}
|
||||
double PyWrap::score(const std::string& moduleName, const std::string& className, CPyObject& X, CPyObject& y)
|
||||
|
8
src/RandomForest.cc
Normal file
8
src/RandomForest.cc
Normal file
@@ -0,0 +1,8 @@
|
||||
#include "RandomForest.h"
|
||||
|
||||
namespace pywrap {
|
||||
std::string RandomForest::version()
|
||||
{
|
||||
return callMethodString("1.0");
|
||||
}
|
||||
} /* namespace pywrap */
|
13
src/RandomForest.h
Normal file
13
src/RandomForest.h
Normal file
@@ -0,0 +1,13 @@
|
||||
#ifndef RANDOMFOREST_H
|
||||
#define RANDOMFOREST_H
|
||||
#include "PyClassifier.h"
|
||||
|
||||
namespace pywrap {
|
||||
class RandomForest : public PyClassifier {
|
||||
public:
|
||||
RandomForest() : PyClassifier("sklearn.ensemble", "RandomForestClassifier") {};
|
||||
~RandomForest() = default;
|
||||
std::string version();
|
||||
};
|
||||
} /* namespace pywrap */
|
||||
#endif /* RANDOMFOREST_H */
|
@@ -3,6 +3,6 @@
|
||||
namespace pywrap {
|
||||
std::string SVC::version()
|
||||
{
|
||||
return callMethodString("_repr_html_");
|
||||
return callMethodString("1.0");
|
||||
}
|
||||
} /* namespace pywrap */
|
8
src/XGBoost.cc
Normal file
8
src/XGBoost.cc
Normal file
@@ -0,0 +1,8 @@
|
||||
#include "XGBoost.h"
|
||||
|
||||
namespace pywrap {
|
||||
std::string XGBoost::version()
|
||||
{
|
||||
return callMethodString("1.0");
|
||||
}
|
||||
} /* namespace pywrap */
|
13
src/XGBoost.h
Normal file
13
src/XGBoost.h
Normal file
@@ -0,0 +1,13 @@
|
||||
#ifndef XGBOOST_H
|
||||
#define XGBOOST_H
|
||||
#include "PyClassifier.h"
|
||||
|
||||
namespace pywrap {
|
||||
class XGBoost : public PyClassifier {
|
||||
public:
|
||||
XGBoost() : PyClassifier("xgboost", "XGBClassifier") {};
|
||||
~XGBoost() = default;
|
||||
std::string version();
|
||||
};
|
||||
} /* namespace pywrap */
|
||||
#endif /* XGBOOST_H */
|
22
src/main.cc
22
src/main.cc
@@ -7,6 +7,7 @@
|
||||
#include <tuple>
|
||||
#include "STree.h"
|
||||
#include "SVC.h"
|
||||
#include "RandomForest.h"
|
||||
|
||||
using namespace std;
|
||||
using namespace torch;
|
||||
@@ -44,25 +45,32 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
cout << "* Begin." << endl;
|
||||
{
|
||||
auto [X, y, features, className, states] = loadDataset("wine", false);
|
||||
auto datasetName = "iris";
|
||||
bool class_last = true;
|
||||
auto [X, y, features, className, states] = loadDataset(datasetName, class_last);
|
||||
cout << "Dataset: " << datasetName << endl;
|
||||
cout << "X: " << X.sizes() << endl;
|
||||
cout << "y: " << y.sizes() << endl;
|
||||
auto clf = pywrap::STree();
|
||||
cout << "STree Version: " << clf.version() << endl;
|
||||
if (true) {
|
||||
auto svc = pywrap::SVC();
|
||||
svc.fit(X, y, features, className, states);
|
||||
cout << "SVC Score: " << svc.score(X, y) << endl;
|
||||
}
|
||||
auto svc = pywrap::SVC();
|
||||
svc.fit(X, y, features, className, states);
|
||||
cout << "Graph: " << endl << clf.graph() << endl;
|
||||
clf.fit(X, y, features, className, states);
|
||||
cout << "STree Score: " << clf.score(X, y) << endl;
|
||||
auto prediction = clf.predict(X);
|
||||
cout << "Prediction: " << endl << "{";
|
||||
for (int i = 0; i < prediction.size(0); ++i) {
|
||||
cout << prediction[i].item<int>() << ", ";
|
||||
}
|
||||
cout << "}" << endl;
|
||||
auto rf = pywrap::RandomForest();
|
||||
rf.fit(X, y, features, className, states);
|
||||
auto xg = pywrap::RandomForest();
|
||||
xg.fit(X, y, features, className, states);
|
||||
cout << "STree Score ......: " << clf.score(X, y) << endl;
|
||||
cout << "RandomForest Score: " << rf.score(X, y) << endl;
|
||||
cout << "SVC Score ........: " << svc.score(X, y) << endl;
|
||||
cout << "XGBoost Score ....: " << xg.score(X, y) << endl;
|
||||
}
|
||||
cout << "* End." << endl;
|
||||
}
|
Reference in New Issue
Block a user