diff --git a/Discretizer.cpp b/Discretizer.cpp index f616eb8..9d637ca 100644 --- a/Discretizer.cpp +++ b/Discretizer.cpp @@ -20,7 +20,7 @@ namespace mdlp { { auto num_elements = X_.numel(); samples_t X(X_.data_ptr(), X_.data_ptr() + num_elements); - labels_t y(y_.data_ptr(), y_.data_ptr() + num_elements); + labels_t y(y_.data_ptr(), y_.data_ptr() + num_elements); fit(X, y); } torch::Tensor Discretizer::transform_t(torch::Tensor& X_) @@ -28,14 +28,14 @@ namespace mdlp { auto num_elements = X_.numel(); samples_t X(X_.data_ptr(), X_.data_ptr() + num_elements); auto result = transform(X); - return torch::tensor(result, torch::kInt64); + return torch::tensor(result, torch::kInt32); } torch::Tensor Discretizer::fit_transform_t(torch::Tensor& X_, torch::Tensor& y_) { auto num_elements = X_.numel(); samples_t X(X_.data_ptr(), X_.data_ptr() + num_elements); - labels_t y(y_.data_ptr(), y_.data_ptr() + num_elements); + labels_t y(y_.data_ptr(), y_.data_ptr() + num_elements); auto result = fit_transform(X, y); - return torch::tensor(result, torch::kInt64); + return torch::tensor(result, torch::kInt32); } } \ No newline at end of file diff --git a/sample/sample.cpp b/sample/sample.cpp index 72a16d8..376c407 100644 --- a/sample/sample.cpp +++ b/sample/sample.cpp @@ -139,12 +139,12 @@ void process_file(const string& path, const string& file_name, bool class_last, std::cout << std::fixed << std::setprecision(1) << X[0][i] << " " << data[i] << std::endl; } auto Xt = torch::tensor(X[0], torch::kFloat32); - auto yt = torch::tensor(y, torch::kInt64); + auto yt = torch::tensor(y, torch::kInt32); //test.fit_t(Xt, yt); auto result = test.fit_transform_t(Xt, yt); std::cout << "Transformed data (torch)...: " << std::endl; for (int i = 130; i < 135; i++) { - std::cout << std::fixed << std::setprecision(1) << Xt[i].item() << " " << result[i].item() << std::endl; + std::cout << std::fixed << std::setprecision(1) << Xt[i].item() << " " << result[i].item() << std::endl; } auto disc = mdlp::BinDisc(3); auto res_v = disc.fit_transform(X[0], y); @@ -152,7 +152,7 @@ void process_file(const string& path, const string& file_name, bool class_last, auto res_t = disc.transform_t(Xt); std::cout << "Transformed data (BinDisc)...: " << std::endl; for (int i = 130; i < 135; i++) { - std::cout << std::fixed << std::setprecision(1) << Xt[i].item() << " " << res_v[i] << " " << res_t[i].item() << std::endl; + std::cout << std::fixed << std::setprecision(1) << Xt[i].item() << " " << res_v[i] << " " << res_t[i].item() << std::endl; } } diff --git a/tests/BinDisc_unittest.cpp b/tests/BinDisc_unittest.cpp index cad67a5..2d4437c 100644 --- a/tests/BinDisc_unittest.cpp +++ b/tests/BinDisc_unittest.cpp @@ -335,10 +335,10 @@ namespace mdlp { auto Xtt = fit_transform(X[0], file.getY()); EXPECT_EQ(expected, Xtt); auto Xt_t = torch::tensor(X[0], torch::kFloat32); - auto y_t = torch::tensor(file.getY(), torch::kInt64); + auto y_t = torch::tensor(file.getY(), torch::kInt32); auto Xtt_t = fit_transform_t(Xt_t, y_t); for (int i = 0; i < expected.size(); i++) - EXPECT_EQ(expected[i], Xtt_t[i].item()); + EXPECT_EQ(expected[i], Xtt_t[i].item()); } TEST_F(TestBinDisc4Q, irisQuantile) { @@ -352,13 +352,13 @@ namespace mdlp { auto Xtt = fit_transform(X[0], file.getY()); EXPECT_EQ(expected, Xtt); auto Xt_t = torch::tensor(X[0], torch::kFloat32); - auto y_t = torch::tensor(file.getY(), torch::kInt64); + auto y_t = torch::tensor(file.getY(), torch::kInt32); auto Xtt_t = fit_transform_t(Xt_t, y_t); for (int i = 0; i < expected.size(); i++) - EXPECT_EQ(expected[i], Xtt_t[i].item()); + EXPECT_EQ(expected[i], Xtt_t[i].item()); fit_t(Xt_t, y_t); auto Xt_t2 = transform_t(Xt_t); for (int i = 0; i < expected.size(); i++) - EXPECT_EQ(expected[i], Xt_t2[i].item()); + EXPECT_EQ(expected[i], Xt_t2[i].item()); } }