mirror of
https://github.com/rmontanana/mdlp.git
synced 2025-08-15 15:35:55 +00:00
Fix int type
This commit is contained in:
@@ -20,7 +20,7 @@ namespace mdlp {
|
||||
{
|
||||
auto num_elements = X_.numel();
|
||||
samples_t X(X_.data_ptr<precision_t>(), X_.data_ptr<precision_t>() + num_elements);
|
||||
labels_t y(y_.data_ptr<int64_t>(), y_.data_ptr<int64_t>() + num_elements);
|
||||
labels_t y(y_.data_ptr<int>(), y_.data_ptr<int>() + num_elements);
|
||||
fit(X, y);
|
||||
}
|
||||
torch::Tensor Discretizer::transform_t(torch::Tensor& X_)
|
||||
@@ -28,14 +28,14 @@ namespace mdlp {
|
||||
auto num_elements = X_.numel();
|
||||
samples_t X(X_.data_ptr<float>(), X_.data_ptr<float>() + num_elements);
|
||||
auto result = transform(X);
|
||||
return torch::tensor(result, torch::kInt64);
|
||||
return torch::tensor(result, torch::kInt32);
|
||||
}
|
||||
torch::Tensor Discretizer::fit_transform_t(torch::Tensor& X_, torch::Tensor& y_)
|
||||
{
|
||||
auto num_elements = X_.numel();
|
||||
samples_t X(X_.data_ptr<precision_t>(), X_.data_ptr<precision_t>() + num_elements);
|
||||
labels_t y(y_.data_ptr<int64_t>(), y_.data_ptr<int64_t>() + num_elements);
|
||||
labels_t y(y_.data_ptr<int>(), y_.data_ptr<int>() + num_elements);
|
||||
auto result = fit_transform(X, y);
|
||||
return torch::tensor(result, torch::kInt64);
|
||||
return torch::tensor(result, torch::kInt32);
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user