diff --git a/bayesnet/classifiers/XSPODE.cc b/bayesnet/classifiers/XSPODE.cc index b336b17..ff7d332 100644 --- a/bayesnet/classifiers/XSPODE.cc +++ b/bayesnet/classifiers/XSPODE.cc @@ -45,7 +45,7 @@ namespace bayesnet { n = X.size(); buildModel(weights_); trainModel(weights_, smoothing); - fitted=true; + fitted = true; } // -------------------------------------- @@ -264,7 +264,7 @@ namespace bayesnet { normalize(probs); return probs; } - std::vector> XSpode::predict_proba(std::vector>& test_data) + std::vector> XSpode::predict_proba(std::vector>& test_data) { int test_size = test_data[0].size(); int sample_size = test_data.size(); @@ -397,22 +397,15 @@ namespace bayesnet { } return result; } - torch::Tensor XSpode::predict_proba(torch::Tensor& X) - { - auto X_ = TensorUtils::to_matrix(X); - auto result_v = predict_proba(X_); - torch::Tensor result; - for (int i = 0; i < result_v.size(); ++i) { - result.index_put_({ i, "..." }, torch::tensor(result_v[i], torch::kDouble)); - } - return result; - } - torch::Tensor XSpode::predict(torch::Tensor& X) + torch::Tensor XSpode::predict_proba(torch::Tensor& X) { auto X_ = TensorUtils::to_matrix(X); - auto predict = predict(X_); - return TensorUtils::to_tensor(predict); + auto result_v = predict_proba(X_); + torch::Tensor result; + for (int i = 0; i < result_v.size(); ++i) { + result.index_put_({ i, "..." }, torch::tensor(result_v[i], torch::kDouble)); + } + return result; } - } diff --git a/bayesnet/classifiers/XSPODE.h b/bayesnet/classifiers/XSPODE.h index fe29f34..215e8ab 100644 --- a/bayesnet/classifiers/XSPODE.h +++ b/bayesnet/classifiers/XSPODE.h @@ -28,7 +28,7 @@ namespace bayesnet { int getNumberOfStates() const override; int getClassNumStates() const override; std::vector& getStates(); - std::vector graph(const std::string& title) const override { return std::vector({title}); } + std::vector graph(const std::string& title) const override { return std::vector({ title }); } void fit(std::vector>& X, std::vector& y, torch::Tensor& weights_, const Smoothing_t smoothing); void setHyperparameters(const nlohmann::json& hyperparameters_) override; @@ -38,7 +38,6 @@ namespace bayesnet { torch::Tensor predict(torch::Tensor& X) override; std::vector predict(std::vector>& X) override; torch::Tensor predict_proba(torch::Tensor& X) override; - std::vector> predict_proba(std::vector>& X) override; protected: void buildModel(const torch::Tensor& weights) override; void trainModel(const torch::Tensor& weights, const bayesnet::Smoothing_t smoothing) override; diff --git a/lib/catch2 b/lib/catch2 new file mode 160000 index 0000000..029fe3b --- /dev/null +++ b/lib/catch2 @@ -0,0 +1 @@ +Subproject commit 029fe3b4609dd84cd939b73357f37bbb75bcf82f diff --git a/lib/folding b/lib/folding index 9652853..2ac43e3 160000 --- a/lib/folding +++ b/lib/folding @@ -1 +1 @@ -Subproject commit 9652853d692ed3b8a38d89f70559209ffb988020 +Subproject commit 2ac43e32ac1eac0c986702ec526cf5367a565ef0 diff --git a/tests/lib/catch2 b/tests/lib/catch2 index 0321d2f..506276c 160000 --- a/tests/lib/catch2 +++ b/tests/lib/catch2 @@ -1 +1 @@ -Subproject commit 0321d2fce328b5e2ad106a8230ff20e0d5bf5501 +Subproject commit 506276c59217429c93abd2fe9507c7f45eb81072