Add XGBoost Classifier
This commit is contained in:
@@ -1,16 +1,11 @@
|
|||||||
#include "XGBoost.h"
|
#include "XGBoost.h"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
//See https ://stackoverflow.com/questions/36071672/using-xgboost-in-c
|
//See https ://stackoverflow.com/questions/36071672/using-xgboost-in-c
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
namespace pywrap {
|
namespace pywrap {
|
||||||
|
XGBoost::XGBoost() : PyClassifier("xgboost", "XGBClassifier")
|
||||||
|
{
|
||||||
|
validHyperparameters = { "tree_method", "early_stopping_rounds", "n_jobs" };
|
||||||
|
}
|
||||||
std::string XGBoost::version()
|
std::string XGBoost::version()
|
||||||
{
|
{
|
||||||
return callMethodString("1.0");
|
return callMethodString("1.0");
|
||||||
|
@@ -5,7 +5,7 @@
|
|||||||
namespace pywrap {
|
namespace pywrap {
|
||||||
class XGBoost : public PyClassifier {
|
class XGBoost : public PyClassifier {
|
||||||
public:
|
public:
|
||||||
XGBoost() : PyClassifier("xgboost", "XGBClassifier") {};
|
XGBoost();
|
||||||
~XGBoost() = default;
|
~XGBoost() = default;
|
||||||
std::string version();
|
std::string version();
|
||||||
};
|
};
|
||||||
|
@@ -8,8 +8,10 @@
|
|||||||
#include "STree.h"
|
#include "STree.h"
|
||||||
#include "SVC.h"
|
#include "SVC.h"
|
||||||
#include "RandomForest.h"
|
#include "RandomForest.h"
|
||||||
|
#include "XGBoost.h"
|
||||||
#include "ODTE.h"
|
#include "ODTE.h"
|
||||||
#include "TestUtils.h"
|
#include "TestUtils.h"
|
||||||
|
#include <nlohmann/json.hpp>
|
||||||
|
|
||||||
TEST_CASE("Test Python Classifiers score", "[PyClassifiers]")
|
TEST_CASE("Test Python Classifiers score", "[PyClassifiers]")
|
||||||
{
|
{
|
||||||
@@ -71,4 +73,14 @@ TEST_CASE("Get num features & num edges", "[PyClassifiers]")
|
|||||||
clf.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest);
|
clf.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest);
|
||||||
REQUIRE(clf.getNumberOfNodes() == 10);
|
REQUIRE(clf.getNumberOfNodes() == 10);
|
||||||
REQUIRE(clf.getNumberOfEdges() == 10);
|
REQUIRE(clf.getNumberOfEdges() == 10);
|
||||||
|
}
|
||||||
|
TEST_CASE("XGBoost", "[PyClassifiers]")
|
||||||
|
{
|
||||||
|
auto raw = RawDatasets("iris", true);
|
||||||
|
auto clf = pywrap::XGBoost();
|
||||||
|
clf.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest);
|
||||||
|
nlohmann::json hyperparameters = { "n_jobs=1" };
|
||||||
|
clf.setHyperparameters(hyperparameters);
|
||||||
|
auto score = clf.score(raw.Xt, raw.yt);
|
||||||
|
REQUIRE(score == Catch::Approx(0.98).epsilon(raw.epsilon));
|
||||||
}
|
}
|
Reference in New Issue
Block a user