diff --git a/pyclfs/PyClassifier.cc b/pyclfs/PyClassifier.cc
index 359cae1..f211be8 100644
--- a/pyclfs/PyClassifier.cc
+++ b/pyclfs/PyClassifier.cc
@@ -93,11 +93,19 @@ namespace pywrap {
             PyErr_Print();
             throw std::runtime_error("Error creating object for predict in " + module + " and class " + className);
         }
-        int* data = reinterpret_cast<int*>(prediction.get_data());
-        std::vector<int> vPrediction(data, data + prediction.shape(0));
-        auto resultTensor = torch::tensor(vPrediction, torch::kInt32);
-        Py_XDECREF(incoming);
-        return resultTensor;
+        if (xgboost) {
+            long* data = reinterpret_cast<long*>(prediction.get_data());
+            std::vector<long> vPrediction(data, data + prediction.shape(0));
+            auto resultTensor = torch::tensor(vPrediction, torch::kInt32);
+            Py_XDECREF(incoming);
+            return resultTensor;
+        } else {
+            int* data = reinterpret_cast<int*>(prediction.get_data());
+            std::vector<int> vPrediction(data, data + prediction.shape(0));
+            auto resultTensor = torch::tensor(vPrediction, torch::kInt32);
+            Py_XDECREF(incoming);
+            return resultTensor;
+        }
     }
     torch::Tensor PyClassifier::predict_proba(torch::Tensor& X)
     {
@@ -118,11 +126,19 @@ namespace pywrap {
             PyErr_Print();
             throw std::runtime_error("Error creating object for predict_proba in " + module + " and class " + className);
         }
-        double* data = reinterpret_cast<double*>(prediction.get_data());
-        std::vector<double> vPrediction(data, data + prediction.shape(0) * prediction.shape(1));
-        auto resultTensor = torch::tensor(vPrediction, torch::kFloat64).reshape({ prediction.shape(0), prediction.shape(1) });
-        Py_XDECREF(incoming);
-        return resultTensor;
+        if (xgboost) {
+            float* data = reinterpret_cast<float*>(prediction.get_data());
+            std::vector<float> vPrediction(data, data + prediction.shape(0) * prediction.shape(1));
+            auto resultTensor = torch::tensor(vPrediction, torch::kFloat64).reshape({ prediction.shape(0), prediction.shape(1) });
+            Py_XDECREF(incoming);
+            return resultTensor;
+        } else {
+            double* data = reinterpret_cast<double*>(prediction.get_data());
+            std::vector<double> vPrediction(data, data + prediction.shape(0) * prediction.shape(1));
+            auto resultTensor = torch::tensor(vPrediction, torch::kFloat64).reshape({ prediction.shape(0), prediction.shape(1) });
+            Py_XDECREF(incoming);
+            return resultTensor;
+        }
     }
     float PyClassifier::score(torch::Tensor& X, torch::Tensor& y)
     {
@@ -135,4 +151,4 @@ namespace pywrap {
     {
         this->hyperparameters = hyperparameters;
     }
-} /* namespace pywrap */
\ No newline at end of file
+} /* namespace pywrap */
diff --git a/pyclfs/PyClassifier.h b/pyclfs/PyClassifier.h
index ebd6a27..a158d03 100644
--- a/pyclfs/PyClassifier.h
+++ b/pyclfs/PyClassifier.h
@@ -49,6 +49,7 @@ namespace pywrap {
         nlohmann::json hyperparameters;
         void trainModel(const torch::Tensor& weights, const bayesnet::Smoothing_t smoothing = bayesnet::Smoothing_t::NONE) override {};
         std::vector<std::string> notes;
+        bool xgboost = false;
     private:
         PyWrap* pyWrap;
         std::string module;
diff --git a/pyclfs/XGBoost.cc b/pyclfs/XGBoost.cc
index e1bcfb6..c95223e 100644
--- a/pyclfs/XGBoost.cc
+++ b/pyclfs/XGBoost.cc
@@ -5,5 +5,6 @@ namespace pywrap {
     XGBoost::XGBoost() : PyClassifier("xgboost", "XGBClassifier", true)
     {
         validHyperparameters = { "tree_method", "early_stopping_rounds", "n_jobs" };
+        xgboost = true;
     }
 } /* namespace pywrap */
\ No newline at end of file
diff --git a/tests/TestPythonClassifiers.cc b/tests/TestPythonClassifiers.cc
index b36aede..903f468 100644
--- a/tests/TestPythonClassifiers.cc
+++ b/tests/TestPythonClassifiers.cc
@@ -116,23 +116,30 @@ TEST_CASE("XGBoost", "[PyClassifiers]")
     clf.setHyperparameters(hyperparameters);
     auto score = clf.score(raw.Xt, raw.yt);
     REQUIRE(score == Catch::Approx(0.98).epsilon(raw.epsilon));
+    std::cout << "XGBoost score: " << score << std::endl;
 }
-// TEST_CASE("XGBoost predict proba", "[PyClassifiers]")
-// {
-//     auto raw = RawDatasets("iris", true);
-//     auto clf = pywrap::XGBoost();
-//     clf.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest);
-//     // nlohmann::json hyperparameters = { "n_jobs=1" };
-//     // clf.setHyperparameters(hyperparameters);
-//     auto predict = clf.predict(raw.Xt);
-//     for (int row = 0; row < predict.size(0); row++) {
-//         auto sum = 0.0;
-//         for (int col = 0; col < predict.size(1); col++) {
-//             std::cout << std::setw(12) << std::setprecision(10) << predict[row][col].item<double>() << " ";
-//             sum += predict[row][col].item<double>();
-//         }
-//         std::cout << std::endl;
-//         // REQUIRE(sum == Catch::Approx(1.0).epsilon(raw.epsilon));
-//     }
-//     std::cout << predict << std::endl;
-// }
\ No newline at end of file
+TEST_CASE("XGBoost predict proba", "[PyClassifiers]")
+{
+    auto raw = RawDatasets("iris", true);
+    auto clf = pywrap::XGBoost();
+    clf.fit(raw.Xt, raw.yt, raw.featurest, raw.classNamet, raw.statest);
+    // nlohmann::json hyperparameters = { "n_jobs=1" };
+    // clf.setHyperparameters(hyperparameters);
+    auto predict_proba = clf.predict_proba(raw.Xt);
+    auto predict = clf.predict(raw.Xt);
+    // std::cout << "Predict proba: " << predict_proba << std::endl;
+    // std::cout << "Predict proba size: " << predict_proba.sizes() << std::endl;
+    // assert(predict.size(0) == predict_proba.size(0));
+    for (int row = 0; row < predict_proba.size(0); row++) {
+        // auto sum = 0.0;
+        // std::cout << "Row " << std::setw(3) << row << ": ";
+        // for (int col = 0; col < predict_proba.size(1); col++) {
+        //     std::cout << std::setw(9) << std::fixed << std::setprecision(7) << predict_proba[row][col].item<double>() << " ";
+        //     sum += predict_proba[row][col].item<double>();
+        // }
+        // std::cout << " -> " << std::setw(9) << std::fixed << std::setprecision(7) << sum << " -> " << torch::argmax(predict_proba[row]).item<int>() << " = " << predict[row].item<int>() << std::endl;
+        // // REQUIRE(sum == Catch::Approx(1.0).epsilon(raw.epsilon));
+        REQUIRE(torch::argmax(predict_proba[row]).item<int>() == predict[row].item<int>());
+        REQUIRE(torch::sum(predict_proba[row]).item<double>() == Catch::Approx(1.0).epsilon(raw.epsilon));
+    }
+}
\ No newline at end of file
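
Note (reviewer sketch, not part of the patch): the new xgboost flag exists because, as the casts in the patch indicate, XGBClassifier hands its predictions back as int64 and its probabilities as float32, while the other wrapped classifiers deliver int32/float64 buffers, so the raw numpy buffer must be reinterpreted with the matching element type before the torch tensor is built. The standalone example below (hypothetical names, assumes a little-endian machine) shows how reading an int64 label buffer as int32 scrambles the labels, which is exactly what the new branch avoids.

    // Standalone illustration of the dtype mismatch handled by the patch.
    // Hypothetical code, not taken from pyclfs; any C++11 compiler will do.
    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main()
    {
        // Stand-in for the buffer behind np::ndarray::get_data() when the
        // prediction comes from XGBClassifier: labels stored as int64.
        std::vector<int64_t> xgb_labels = { 0, 2, 1, 1 };
        void* raw = xgb_labels.data();

        // Wrong: interpreting the int64 buffer as int32 reads the four leading
        // 32-bit words, i.e. only the first two labels plus their zero upper
        // halves (0, 0, 2, 0 on a little-endian machine).
        auto* as_int32 = reinterpret_cast<int32_t*>(raw);
        std::vector<int32_t> wrong(as_int32, as_int32 + xgb_labels.size());

        // Right (the approach the xgboost branch takes): reinterpret as long
        // first, and only then narrow, e.g. via torch::tensor(v, torch::kInt32).
        auto* as_int64 = reinterpret_cast<int64_t*>(raw);
        std::vector<int64_t> right(as_int64, as_int64 + xgb_labels.size());

        for (auto v : wrong) std::cout << v << ' ';
        std::cout << '\n';
        for (auto v : right) std::cout << v << ' ';
        std::cout << '\n';
    }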