diff --git a/bayesnet/classifiers/.XSPODE.h.swp b/bayesnet/classifiers/.XSPODE.h.swp new file mode 100644 index 0000000..024d03c Binary files /dev/null and b/bayesnet/classifiers/.XSPODE.h.swp differ diff --git a/bayesnet/classifiers/XSPODE.cc b/bayesnet/classifiers/XSPODE.cc index 4449849..b336b17 100644 --- a/bayesnet/classifiers/XSPODE.cc +++ b/bayesnet/classifiers/XSPODE.cc @@ -264,7 +264,7 @@ namespace bayesnet { normalize(probs); return probs; } - std::vector> XSpode::predict_proba(const std::vector>& test_data) + std::vector> XSpode::predict_proba(std::vector>& test_data) { int test_size = test_data[0].size(); int sample_size = test_data.size(); @@ -390,14 +390,22 @@ namespace bayesnet { torch::Tensor XSpode::predict(torch::Tensor& X) { auto X_ = TensorUtils::to_matrix(X); - auto result = predict(X_); - return TensorUtils::to_tensor(result); + auto result_v = predict(X_); + torch::Tensor result; + for (int i = 0; i < result_v.size(); ++i) { + result.index_put_({ i, "..." }, torch::tensor(result_v[i], torch::kInt32)); + } + return result; } torch::Tensor XSpode::predict_proba(torch::Tensor& X) { auto X_ = TensorUtils::to_matrix(X); - auto result = predict_proba(X_); - return TensorUtils::to_tensor(result); + auto result_v = predict_proba(X_); + torch::Tensor result; + for (int i = 0; i < result_v.size(); ++i) { + result.index_put_({ i, "..." }, torch::tensor(result_v[i], torch::kDouble)); + } + return result; } torch::Tensor XSpode::predict(torch::Tensor& X) { diff --git a/bayesnet/classifiers/XSPODE.h b/bayesnet/classifiers/XSPODE.h index 7f57b22..fe29f34 100644 --- a/bayesnet/classifiers/XSPODE.h +++ b/bayesnet/classifiers/XSPODE.h @@ -18,7 +18,7 @@ namespace bayesnet { public: explicit XSpode(int spIndex); std::vector predict_proba(const std::vector& instance) const; - std::vector> predict_proba(const std::vector>& test_data); + std::vector> predict_proba(std::vector>& X) override; int predict(const std::vector& instance) const; void normalize(std::vector& v) const; std::string to_string() const;