Add XGBoost Classifier

This commit is contained in:
2024-01-16 17:06:00 +01:00
parent f46f6dcbb2
commit 071f7167e6
3 changed files with 17 additions and 10 deletions

View File

@@ -1,16 +1,11 @@
#include "XGBoost.h"
//See https ://stackoverflow.com/questions/36071672/using-xgboost-in-c
namespace pywrap {
XGBoost::XGBoost() : PyClassifier("xgboost", "XGBClassifier")
{
validHyperparameters = { "tree_method", "early_stopping_rounds", "n_jobs" };
}
std::string XGBoost::version()
{
return callMethodString("1.0");

View File

@@ -5,7 +5,7 @@
namespace pywrap {
class XGBoost : public PyClassifier {
public:
XGBoost() : PyClassifier("xgboost", "XGBClassifier") {};
XGBoost();
~XGBoost() = default;
std::string version();
};

View File

@@ -8,8 +8,10 @@
#include "STree.h"
#include "SVC.h"
#include "RandomForest.h"
#include "XGBoost.h"
#include "ODTE.h"
#include "TestUtils.h"
#include <nlohmann/json.hpp>
TEST_CASE("Test Python Classifiers score", "[PyClassifiers]")
{
@@ -72,3 +74,13 @@ TEST_CASE("Get num features & num edges", "[PyClassifiers]")
REQUIRE(clf.getNumberOfNodes() == 10);
REQUIRE(clf.getNumberOfEdges() == 10);
}
TEST_CASE("XGBoost", "[PyClassifiers]")
{
auto raw = RawDatasets("iris", true);
auto clf = pywrap::XGBoost();
clf.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest);
nlohmann::json hyperparameters = { "n_jobs=1" };
clf.setHyperparameters(hyperparameters);
auto score = clf.score(raw.Xt, raw.yt);
REQUIRE(score == Catch::Approx(0.98).epsilon(raw.epsilon));
}