Fix XSpode
This commit is contained in:
@@ -407,12 +407,5 @@ namespace bayesnet {
|
||||
}
|
||||
return result;
|
||||
}
|
||||
torch::Tensor XSpode::predict(torch::Tensor& X)
|
||||
{
|
||||
auto X_ = TensorUtils::to_matrix(X);
|
||||
auto predict = predict(X_);
|
||||
return TensorUtils::to_tensor(predict);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
@@ -38,7 +38,6 @@ namespace bayesnet {
|
||||
torch::Tensor predict(torch::Tensor& X) override;
|
||||
std::vector<int> predict(std::vector<std::vector<int>>& X) override;
|
||||
torch::Tensor predict_proba(torch::Tensor& X) override;
|
||||
std::vector<std::vector<double>> predict_proba(std::vector<std::vector<int>>& X) override;
|
||||
protected:
|
||||
void buildModel(const torch::Tensor& weights) override;
|
||||
void trainModel(const torch::Tensor& weights, const bayesnet::Smoothing_t smoothing) override;
|
||||
|
1
lib/catch2
Submodule
1
lib/catch2
Submodule
Submodule lib/catch2 added at 029fe3b460
Submodule lib/folding updated: 9652853d69...2ac43e32ac
Submodule tests/lib/catch2 updated: 0321d2fce3...506276c592
Reference in New Issue
Block a user