Fix XSpode

This commit is contained in:
2025-03-10 14:21:01 +01:00
parent 7a8e0391dc
commit d1b235261e
3 changed files with 14 additions and 6 deletions

Binary file not shown.

View File

@@ -264,7 +264,7 @@ namespace bayesnet {
normalize(probs); normalize(probs);
return probs; return probs;
} }
std::vector<std::vector<double>> XSpode::predict_proba(const std::vector<std::vector<int>>& test_data) std::vector<std::vector<double>> XSpode::predict_proba(std::vector<std::vector<int>>& test_data)
{ {
int test_size = test_data[0].size(); int test_size = test_data[0].size();
int sample_size = test_data.size(); int sample_size = test_data.size();
@@ -390,14 +390,22 @@ namespace bayesnet {
torch::Tensor XSpode::predict(torch::Tensor& X) torch::Tensor XSpode::predict(torch::Tensor& X)
{ {
auto X_ = TensorUtils::to_matrix(X); auto X_ = TensorUtils::to_matrix(X);
auto result = predict(X_); auto result_v = predict(X_);
return TensorUtils::to_tensor(result); torch::Tensor result;
for (int i = 0; i < result_v.size(); ++i) {
result.index_put_({ i, "..." }, torch::tensor(result_v[i], torch::kInt32));
}
return result;
} }
torch::Tensor XSpode::predict_proba(torch::Tensor& X) torch::Tensor XSpode::predict_proba(torch::Tensor& X)
{ {
auto X_ = TensorUtils::to_matrix(X); auto X_ = TensorUtils::to_matrix(X);
auto result = predict_proba(X_); auto result_v = predict_proba(X_);
return TensorUtils::to_tensor<double>(result); torch::Tensor result;
for (int i = 0; i < result_v.size(); ++i) {
result.index_put_({ i, "..." }, torch::tensor(result_v[i], torch::kDouble));
}
return result;
} }
torch::Tensor XSpode::predict(torch::Tensor& X) torch::Tensor XSpode::predict(torch::Tensor& X)
{ {

View File

@@ -18,7 +18,7 @@ namespace bayesnet {
public: public:
explicit XSpode(int spIndex); explicit XSpode(int spIndex);
std::vector<double> predict_proba(const std::vector<int>& instance) const; std::vector<double> predict_proba(const std::vector<int>& instance) const;
std::vector<std::vector<double>> predict_proba(const std::vector<std::vector<int>>& test_data); std::vector<std::vector<double>> predict_proba(std::vector<std::vector<int>>& X) override;
int predict(const std::vector<int>& instance) const; int predict(const std::vector<int>& instance) const;
void normalize(std::vector<double>& v) const; void normalize(std::vector<double>& v) const;
std::string to_string() const; std::string to_string() const;