Fix XSpode
This commit is contained in:
@@ -45,7 +45,7 @@ namespace bayesnet {
|
||||
n = X.size();
|
||||
buildModel(weights_);
|
||||
trainModel(weights_, smoothing);
|
||||
fitted=true;
|
||||
fitted = true;
|
||||
}
|
||||
|
||||
// --------------------------------------
|
||||
@@ -264,7 +264,7 @@ namespace bayesnet {
|
||||
normalize(probs);
|
||||
return probs;
|
||||
}
|
||||
std::vector<std::vector<double>> XSpode::predict_proba(std::vector<std::vector<int>>& test_data)
|
||||
std::vector<std::vector<double>> XSpode::predict_proba(std::vector<std::vector<int>>& test_data)
|
||||
{
|
||||
int test_size = test_data[0].size();
|
||||
int sample_size = test_data.size();
|
||||
@@ -397,22 +397,15 @@ namespace bayesnet {
|
||||
}
|
||||
return result;
|
||||
}
|
||||
torch::Tensor XSpode::predict_proba(torch::Tensor& X)
|
||||
{
|
||||
auto X_ = TensorUtils::to_matrix(X);
|
||||
auto result_v = predict_proba(X_);
|
||||
torch::Tensor result;
|
||||
for (int i = 0; i < result_v.size(); ++i) {
|
||||
result.index_put_({ i, "..." }, torch::tensor(result_v[i], torch::kDouble));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
torch::Tensor XSpode::predict(torch::Tensor& X)
|
||||
torch::Tensor XSpode::predict_proba(torch::Tensor& X)
|
||||
{
|
||||
auto X_ = TensorUtils::to_matrix(X);
|
||||
auto predict = predict(X_);
|
||||
return TensorUtils::to_tensor(predict);
|
||||
auto result_v = predict_proba(X_);
|
||||
torch::Tensor result;
|
||||
for (int i = 0; i < result_v.size(); ++i) {
|
||||
result.index_put_({ i, "..." }, torch::tensor(result_v[i], torch::kDouble));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user