Optimize BoostAODE -> XBAODE #33

Merged
rmontanana merged 27 commits from WA2DE into main 2025-03-16 17:58:10 +00:00
32 changed files with 1456 additions and 50 deletions
Showing only changes of commit 86cccb6c7b - Show all commits

View File

@@ -407,12 +407,5 @@ namespace bayesnet {
}
return result;
}
torch::Tensor XSpode::predict(torch::Tensor& X)
{
auto X_ = TensorUtils::to_matrix(X);
auto predict = predict(X_);
return TensorUtils::to_tensor(predict);
}
}

View File

@@ -38,7 +38,6 @@ namespace bayesnet {
torch::Tensor predict(torch::Tensor& X) override;
std::vector<int> predict(std::vector<std::vector<int>>& X) override;
torch::Tensor predict_proba(torch::Tensor& X) override;
std::vector<std::vector<double>> predict_proba(std::vector<std::vector<int>>& X) override;
protected:
void buildModel(const torch::Tensor& weights) override;
void trainModel(const torch::Tensor& weights, const bayesnet::Smoothing_t smoothing) override;

1
lib/catch2 Submodule

Submodule lib/catch2 added at 029fe3b460