Fix input data dimensions

This commit is contained in:
2025-07-04 11:10:25 +02:00
parent 37a765e6b0
commit 57c6693842
3 changed files with 41 additions and 32 deletions

View File

@@ -61,23 +61,26 @@ tuple<torch::Tensor, torch::Tensor, std::vector<std::string>, std::string, map<s
auto states = map<std::string, std::vector<int>>();
if (discretize_dataset) {
auto Xr = discretizeDataset(X, y);
// Create tensor as [samples, features] not [features, samples]
Xd = torch::zeros({ static_cast<int>(Xr[0].size()), static_cast<int>(Xr.size()) }, torch::kInt32);
// Create tensor as [features, samples] (bayesnet format)
// Xr has same structure as X: Xr[i] is i-th feature, Xr[i].size() is number of samples
Xd = torch::zeros({ static_cast<int>(Xr.size()), static_cast<int>(Xr[0].size()) }, torch::kInt32);
for (int i = 0; i < features.size(); ++i) {
states[features[i]] = std::vector<int>(*max_element(Xr[i].begin(), Xr[i].end()) + 1);
auto item = states.at(features[i]);
iota(begin(item), end(item), 0);
// Put data as column i (feature i)
Xd.index_put_({ "...", i }, torch::tensor(Xr[i], torch::kInt32));
// Put data as row i (feature i)
Xd.index_put_({ i, "..." }, torch::tensor(Xr[i], torch::kInt32));
}
states[className] = std::vector<int>(*max_element(y.begin(), y.end()) + 1);
iota(begin(states.at(className)), end(states.at(className)), 0);
} else {
// Create tensor as [samples, features] not [features, samples]
Xd = torch::zeros({ static_cast<int>(X[0].size()), static_cast<int>(X.size()) }, torch::kFloat32);
// Create tensor as [features, samples] (bayesnet format)
// X[i] is i-th feature, X[i].size() is number of samples
// We want tensor[features, samples], so [X.size(), X[0].size()]
Xd = torch::zeros({ static_cast<int>(X.size()), static_cast<int>(X[0].size()) }, torch::kFloat32);
for (int i = 0; i < features.size(); ++i) {
// Put data as column i (feature i)
Xd.index_put_({ "...", i }, torch::tensor(X[i]));
// Put data as row i (feature i)
Xd.index_put_({ i, "..." }, torch::tensor(X[i]));
}
}
return { Xd, torch::tensor(y, torch::kInt32), features, className, states };