Fix int type

This commit is contained in:
2024-06-09 00:29:55 +02:00
parent 7938df7f0f
commit c4e6c041fe
3 changed files with 12 additions and 12 deletions

View File

@@ -20,7 +20,7 @@ namespace mdlp {
{ {
auto num_elements = X_.numel(); auto num_elements = X_.numel();
samples_t X(X_.data_ptr<precision_t>(), X_.data_ptr<precision_t>() + num_elements); samples_t X(X_.data_ptr<precision_t>(), X_.data_ptr<precision_t>() + num_elements);
labels_t y(y_.data_ptr<int64_t>(), y_.data_ptr<int64_t>() + num_elements); labels_t y(y_.data_ptr<int>(), y_.data_ptr<int>() + num_elements);
fit(X, y); fit(X, y);
} }
torch::Tensor Discretizer::transform_t(torch::Tensor& X_) torch::Tensor Discretizer::transform_t(torch::Tensor& X_)
@@ -28,14 +28,14 @@ namespace mdlp {
auto num_elements = X_.numel(); auto num_elements = X_.numel();
samples_t X(X_.data_ptr<float>(), X_.data_ptr<float>() + num_elements); samples_t X(X_.data_ptr<float>(), X_.data_ptr<float>() + num_elements);
auto result = transform(X); auto result = transform(X);
return torch::tensor(result, torch::kInt64); return torch::tensor(result, torch::kInt32);
} }
torch::Tensor Discretizer::fit_transform_t(torch::Tensor& X_, torch::Tensor& y_) torch::Tensor Discretizer::fit_transform_t(torch::Tensor& X_, torch::Tensor& y_)
{ {
auto num_elements = X_.numel(); auto num_elements = X_.numel();
samples_t X(X_.data_ptr<precision_t>(), X_.data_ptr<precision_t>() + num_elements); samples_t X(X_.data_ptr<precision_t>(), X_.data_ptr<precision_t>() + num_elements);
labels_t y(y_.data_ptr<int64_t>(), y_.data_ptr<int64_t>() + num_elements); labels_t y(y_.data_ptr<int>(), y_.data_ptr<int>() + num_elements);
auto result = fit_transform(X, y); auto result = fit_transform(X, y);
return torch::tensor(result, torch::kInt64); return torch::tensor(result, torch::kInt32);
} }
} }

View File

@@ -139,12 +139,12 @@ void process_file(const string& path, const string& file_name, bool class_last,
std::cout << std::fixed << std::setprecision(1) << X[0][i] << " " << data[i] << std::endl; std::cout << std::fixed << std::setprecision(1) << X[0][i] << " " << data[i] << std::endl;
} }
auto Xt = torch::tensor(X[0], torch::kFloat32); auto Xt = torch::tensor(X[0], torch::kFloat32);
auto yt = torch::tensor(y, torch::kInt64); auto yt = torch::tensor(y, torch::kInt32);
//test.fit_t(Xt, yt); //test.fit_t(Xt, yt);
auto result = test.fit_transform_t(Xt, yt); auto result = test.fit_transform_t(Xt, yt);
std::cout << "Transformed data (torch)...: " << std::endl; std::cout << "Transformed data (torch)...: " << std::endl;
for (int i = 130; i < 135; i++) { for (int i = 130; i < 135; i++) {
std::cout << std::fixed << std::setprecision(1) << Xt[i].item<float>() << " " << result[i].item<int64_t>() << std::endl; std::cout << std::fixed << std::setprecision(1) << Xt[i].item<float>() << " " << result[i].item<int>() << std::endl;
} }
auto disc = mdlp::BinDisc(3); auto disc = mdlp::BinDisc(3);
auto res_v = disc.fit_transform(X[0], y); auto res_v = disc.fit_transform(X[0], y);
@@ -152,7 +152,7 @@ void process_file(const string& path, const string& file_name, bool class_last,
auto res_t = disc.transform_t(Xt); auto res_t = disc.transform_t(Xt);
std::cout << "Transformed data (BinDisc)...: " << std::endl; std::cout << "Transformed data (BinDisc)...: " << std::endl;
for (int i = 130; i < 135; i++) { for (int i = 130; i < 135; i++) {
std::cout << std::fixed << std::setprecision(1) << Xt[i].item<float>() << " " << res_v[i] << " " << res_t[i].item<int64_t>() << std::endl; std::cout << std::fixed << std::setprecision(1) << Xt[i].item<float>() << " " << res_v[i] << " " << res_t[i].item<int>() << std::endl;
} }
} }

View File

@@ -335,10 +335,10 @@ namespace mdlp {
auto Xtt = fit_transform(X[0], file.getY()); auto Xtt = fit_transform(X[0], file.getY());
EXPECT_EQ(expected, Xtt); EXPECT_EQ(expected, Xtt);
auto Xt_t = torch::tensor(X[0], torch::kFloat32); auto Xt_t = torch::tensor(X[0], torch::kFloat32);
auto y_t = torch::tensor(file.getY(), torch::kInt64); auto y_t = torch::tensor(file.getY(), torch::kInt32);
auto Xtt_t = fit_transform_t(Xt_t, y_t); auto Xtt_t = fit_transform_t(Xt_t, y_t);
for (int i = 0; i < expected.size(); i++) for (int i = 0; i < expected.size(); i++)
EXPECT_EQ(expected[i], Xtt_t[i].item<int64_t>()); EXPECT_EQ(expected[i], Xtt_t[i].item<int>());
} }
TEST_F(TestBinDisc4Q, irisQuantile) TEST_F(TestBinDisc4Q, irisQuantile)
{ {
@@ -352,13 +352,13 @@ namespace mdlp {
auto Xtt = fit_transform(X[0], file.getY()); auto Xtt = fit_transform(X[0], file.getY());
EXPECT_EQ(expected, Xtt); EXPECT_EQ(expected, Xtt);
auto Xt_t = torch::tensor(X[0], torch::kFloat32); auto Xt_t = torch::tensor(X[0], torch::kFloat32);
auto y_t = torch::tensor(file.getY(), torch::kInt64); auto y_t = torch::tensor(file.getY(), torch::kInt32);
auto Xtt_t = fit_transform_t(Xt_t, y_t); auto Xtt_t = fit_transform_t(Xt_t, y_t);
for (int i = 0; i < expected.size(); i++) for (int i = 0; i < expected.size(); i++)
EXPECT_EQ(expected[i], Xtt_t[i].item<int64_t>()); EXPECT_EQ(expected[i], Xtt_t[i].item<int>());
fit_t(Xt_t, y_t); fit_t(Xt_t, y_t);
auto Xt_t2 = transform_t(Xt_t); auto Xt_t2 = transform_t(Xt_t);
for (int i = 0; i < expected.size(); i++) for (int i = 0; i < expected.size(); i++)
EXPECT_EQ(expected[i], Xt_t2[i].item<int64_t>()); EXPECT_EQ(expected[i], Xt_t2[i].item<int>());
} }
} }