diff --git a/pyclfs/ODTE.cc b/pyclfs/ODTE.cc index 11b1433..09ea99b 100644 --- a/pyclfs/ODTE.cc +++ b/pyclfs/ODTE.cc @@ -3,7 +3,7 @@ namespace pywrap { ODTE::ODTE() : PyClassifier("odte", "Odte") { - validHyperparameters = { "n_jobs", "n_estimators", "random_state" }; + validHyperparameters = { "n_jobs", "n_estimators", "random_state", "be_hyperparams" }; } int ODTE::getNumberOfNodes() const { diff --git a/tests/TestPythonClassifiers.cc b/tests/TestPythonClassifiers.cc index 0a4cf6d..9c42f67 100644 --- a/tests/TestPythonClassifiers.cc +++ b/tests/TestPythonClassifiers.cc @@ -114,4 +114,23 @@ TEST_CASE("XGBoost", "[PyClassifiers]") clf.setHyperparameters(hyperparameters); auto score = clf.score(raw.Xt, raw.yt); REQUIRE(score == Catch::Approx(0.98).epsilon(raw.epsilon)); +} +TEST_CASE("XGBoost predict proba", "[PyClassifiers]") +{ + auto raw = RawDatasets("iris", true); + auto clf = pywrap::XGBoost(); + clf.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest); + nlohmann::json hyperparameters = { "n_jobs=1" }; + clf.setHyperparameters(hyperparameters); + auto predict = clf.predict(raw.Xt); + // for (int row = 0; row < predict.size(0); row++) { + // auto sum = 0.0; + // for (int col = 0; col < predict.size(1); col++) { + // std::cout << std::setw(12) << std::setprecision(10) << predict[row][col].item() << " "; + // sum += predict[row][col].item(); + // } + // std::cout << std::endl; + // // REQUIRE(sum == Catch::Approx(1.0).epsilon(raw.epsilon)); + // } + std::cout << predict << std::endl; } \ No newline at end of file