Fix int type

This commit is contained in:
2024-06-09 00:29:55 +02:00
parent 7938df7f0f
commit c4e6c041fe
3 changed files with 12 additions and 12 deletions

View File

@@ -335,10 +335,10 @@ namespace mdlp {
auto Xtt = fit_transform(X[0], file.getY());
EXPECT_EQ(expected, Xtt);
auto Xt_t = torch::tensor(X[0], torch::kFloat32);
auto y_t = torch::tensor(file.getY(), torch::kInt64);
auto y_t = torch::tensor(file.getY(), torch::kInt32);
auto Xtt_t = fit_transform_t(Xt_t, y_t);
for (int i = 0; i < expected.size(); i++)
EXPECT_EQ(expected[i], Xtt_t[i].item<int64_t>());
EXPECT_EQ(expected[i], Xtt_t[i].item<int>());
}
TEST_F(TestBinDisc4Q, irisQuantile)
{
@@ -352,13 +352,13 @@ namespace mdlp {
auto Xtt = fit_transform(X[0], file.getY());
EXPECT_EQ(expected, Xtt);
auto Xt_t = torch::tensor(X[0], torch::kFloat32);
auto y_t = torch::tensor(file.getY(), torch::kInt64);
auto y_t = torch::tensor(file.getY(), torch::kInt32);
auto Xtt_t = fit_transform_t(Xt_t, y_t);
for (int i = 0; i < expected.size(); i++)
EXPECT_EQ(expected[i], Xtt_t[i].item<int64_t>());
EXPECT_EQ(expected[i], Xtt_t[i].item<int>());
fit_t(Xt_t, y_t);
auto Xt_t2 = transform_t(Xt_t);
for (int i = 0; i < expected.size(); i++)
EXPECT_EQ(expected[i], Xt_t2[i].item<int64_t>());
EXPECT_EQ(expected[i], Xt_t2[i].item<int>());
}
}