Fix XSpode

This commit is contained in:
2025-03-10 14:23:47 +01:00
parent d1b235261e
commit 86cccb6c7b
5 changed files with 13 additions and 20 deletions

View File

@@ -45,7 +45,7 @@ namespace bayesnet {
n = X.size(); n = X.size();
buildModel(weights_); buildModel(weights_);
trainModel(weights_, smoothing); trainModel(weights_, smoothing);
fitted=true; fitted = true;
} }
// -------------------------------------- // --------------------------------------
@@ -264,7 +264,7 @@ namespace bayesnet {
normalize(probs); normalize(probs);
return probs; return probs;
} }
std::vector<std::vector<double>> XSpode::predict_proba(std::vector<std::vector<int>>& test_data) std::vector<std::vector<double>> XSpode::predict_proba(std::vector<std::vector<int>>& test_data)
{ {
int test_size = test_data[0].size(); int test_size = test_data[0].size();
int sample_size = test_data.size(); int sample_size = test_data.size();
@@ -397,22 +397,15 @@ namespace bayesnet {
} }
return result; return result;
} }
torch::Tensor XSpode::predict_proba(torch::Tensor& X) torch::Tensor XSpode::predict_proba(torch::Tensor& X)
{
auto X_ = TensorUtils::to_matrix(X);
auto result_v = predict_proba(X_);
torch::Tensor result;
for (int i = 0; i < result_v.size(); ++i) {
result.index_put_({ i, "..." }, torch::tensor(result_v[i], torch::kDouble));
}
return result;
}
torch::Tensor XSpode::predict(torch::Tensor& X)
{ {
auto X_ = TensorUtils::to_matrix(X); auto X_ = TensorUtils::to_matrix(X);
auto predict = predict(X_); auto result_v = predict_proba(X_);
return TensorUtils::to_tensor(predict); torch::Tensor result;
for (int i = 0; i < result_v.size(); ++i) {
result.index_put_({ i, "..." }, torch::tensor(result_v[i], torch::kDouble));
}
return result;
} }
} }

View File

@@ -28,7 +28,7 @@ namespace bayesnet {
int getNumberOfStates() const override; int getNumberOfStates() const override;
int getClassNumStates() const override; int getClassNumStates() const override;
std::vector<int>& getStates(); std::vector<int>& getStates();
std::vector<std::string> graph(const std::string& title) const override { return std::vector<std::string>({title}); } std::vector<std::string> graph(const std::string& title) const override { return std::vector<std::string>({ title }); }
void fit(std::vector<std::vector<int>>& X, std::vector<int>& y, torch::Tensor& weights_, const Smoothing_t smoothing); void fit(std::vector<std::vector<int>>& X, std::vector<int>& y, torch::Tensor& weights_, const Smoothing_t smoothing);
void setHyperparameters(const nlohmann::json& hyperparameters_) override; void setHyperparameters(const nlohmann::json& hyperparameters_) override;
@@ -38,7 +38,6 @@ namespace bayesnet {
torch::Tensor predict(torch::Tensor& X) override; torch::Tensor predict(torch::Tensor& X) override;
std::vector<int> predict(std::vector<std::vector<int>>& X) override; std::vector<int> predict(std::vector<std::vector<int>>& X) override;
torch::Tensor predict_proba(torch::Tensor& X) override; torch::Tensor predict_proba(torch::Tensor& X) override;
std::vector<std::vector<double>> predict_proba(std::vector<std::vector<int>>& X) override;
protected: protected:
void buildModel(const torch::Tensor& weights) override; void buildModel(const torch::Tensor& weights) override;
void trainModel(const torch::Tensor& weights, const bayesnet::Smoothing_t smoothing) override; void trainModel(const torch::Tensor& weights, const bayesnet::Smoothing_t smoothing) override;

1
lib/catch2 Submodule

Submodule lib/catch2 added at 029fe3b460