Fix int type

This commit is contained in:
2024-06-09 00:29:55 +02:00
parent 7938df7f0f
commit c4e6c041fe
3 changed files with 12 additions and 12 deletions

View File

@@ -139,12 +139,12 @@ void process_file(const string& path, const string& file_name, bool class_last,
std::cout << std::fixed << std::setprecision(1) << X[0][i] << " " << data[i] << std::endl;
}
auto Xt = torch::tensor(X[0], torch::kFloat32);
auto yt = torch::tensor(y, torch::kInt64);
auto yt = torch::tensor(y, torch::kInt32);
//test.fit_t(Xt, yt);
auto result = test.fit_transform_t(Xt, yt);
std::cout << "Transformed data (torch)...: " << std::endl;
for (int i = 130; i < 135; i++) {
std::cout << std::fixed << std::setprecision(1) << Xt[i].item<float>() << " " << result[i].item<int64_t>() << std::endl;
std::cout << std::fixed << std::setprecision(1) << Xt[i].item<float>() << " " << result[i].item<int>() << std::endl;
}
auto disc = mdlp::BinDisc(3);
auto res_v = disc.fit_transform(X[0], y);
@@ -152,7 +152,7 @@ void process_file(const string& path, const string& file_name, bool class_last,
auto res_t = disc.transform_t(Xt);
std::cout << "Transformed data (BinDisc)...: " << std::endl;
for (int i = 130; i < 135; i++) {
std::cout << std::fixed << std::setprecision(1) << Xt[i].item<float>() << " " << res_v[i] << " " << res_t[i].item<int64_t>() << std::endl;
std::cout << std::fixed << std::setprecision(1) << Xt[i].item<float>() << " " << res_v[i] << " " << res_t[i].item<int>() << std::endl;
}
}