Optimize BoostAODE -> XBAODE #33

Merged
rmontanana merged 27 commits from WA2DE into main 2025-03-16 17:58:10 +00:00
32 changed files with 1456 additions and 50 deletions
Showing only changes of commit 86cccb6c7b - Show all commits

View File

@@ -407,12 +407,5 @@ namespace bayesnet {
} }
return result; return result;
} }
torch::Tensor XSpode::predict(torch::Tensor& X)
{
auto X_ = TensorUtils::to_matrix(X);
auto predict = predict(X_);
return TensorUtils::to_tensor(predict);
}
} }

View File

@@ -38,7 +38,6 @@ namespace bayesnet {
torch::Tensor predict(torch::Tensor& X) override; torch::Tensor predict(torch::Tensor& X) override;
std::vector<int> predict(std::vector<std::vector<int>>& X) override; std::vector<int> predict(std::vector<std::vector<int>>& X) override;
torch::Tensor predict_proba(torch::Tensor& X) override; torch::Tensor predict_proba(torch::Tensor& X) override;
std::vector<std::vector<double>> predict_proba(std::vector<std::vector<int>>& X) override;
protected: protected:
void buildModel(const torch::Tensor& weights) override; void buildModel(const torch::Tensor& weights) override;
void trainModel(const torch::Tensor& weights, const bayesnet::Smoothing_t smoothing) override; void trainModel(const torch::Tensor& weights, const bayesnet::Smoothing_t smoothing) override;

1
lib/catch2 Submodule

Submodule lib/catch2 added at 029fe3b460