From 4defb279b01b62958970b6ecb367ba4e6d60dcd8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Sun, 14 Jul 2024 17:56:35 +0200 Subject: [PATCH] Enable XGBoost test --- tests/TestPythonClassifiers.cc | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/TestPythonClassifiers.cc b/tests/TestPythonClassifiers.cc index f2079d9..0a4cf6d 100644 --- a/tests/TestPythonClassifiers.cc +++ b/tests/TestPythonClassifiers.cc @@ -105,13 +105,13 @@ TEST_CASE("Predict with non_discretized dataset and comparing to predict_proba", auto accuracy = right / static_cast(predictions.size(0)); REQUIRE(accuracy == Catch::Approx(1.0f).epsilon(raw.epsilon)); } -// TEST_CASE("XGBoost", "[PyClassifiers]") -// { -// auto raw = RawDatasets("iris", true); -// auto clf = pywrap::XGBoost(); -// clf.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest); -// nlohmann::json hyperparameters = { "n_jobs=1" }; -// clf.setHyperparameters(hyperparameters); -// auto score = clf.score(raw.Xt, raw.yt); -// REQUIRE(score == Catch::Approx(0.98).epsilon(raw.epsilon)); -// } \ No newline at end of file +TEST_CASE("XGBoost", "[PyClassifiers]") +{ + auto raw = RawDatasets("iris", true); + auto clf = pywrap::XGBoost(); + clf.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest); + nlohmann::json hyperparameters = { "n_jobs=1" }; + clf.setHyperparameters(hyperparameters); + auto score = clf.score(raw.Xt, raw.yt); + REQUIRE(score == Catch::Approx(0.98).epsilon(raw.epsilon)); +} \ No newline at end of file