Fix some mistakes in tensors treatment

This commit is contained in:
2023-07-26 01:39:01 +02:00
parent be06e475f0
commit 099b4bea09
18 changed files with 255 additions and 72 deletions

View File

@@ -55,16 +55,16 @@ tuple<Tensor, Tensor, vector<string>, string, map<string, vector<int>>> loadData
auto states = map<string, vector<int>>();
if (discretize_dataset) {
auto Xr = discretizeDataset(X, y);
Xd = torch::zeros({ static_cast<int64_t>(Xr[0].size()), static_cast<int64_t>(Xr.size()) }, torch::kInt64);
Xd = torch::zeros({ static_cast<int>(Xr[0].size()), static_cast<int>(Xr.size()) }, torch::kInt32);
for (int i = 0; i < features.size(); ++i) {
states[features[i]] = vector<int>(*max_element(Xr[i].begin(), Xr[i].end()) + 1);
iota(begin(states[features[i]]), end(states[features[i]]), 0);
Xd.index_put_({ "...", i }, torch::tensor(Xr[i], torch::kInt64));
Xd.index_put_({ "...", i }, torch::tensor(Xr[i], torch::kInt32));
}
states[className] = vector<int>(*max_element(y.begin(), y.end()) + 1);
iota(begin(states[className]), end(states[className]), 0);
} else {
Xd = torch::zeros({ static_cast<int64_t>(X[0].size()), static_cast<int64_t>(X.size()) }, torch::kFloat32);
Xd = torch::zeros({ static_cast<int>(X[0].size()), static_cast<int>(X.size()) }, torch::kFloat32);
for (int i = 0; i < features.size(); ++i) {
Xd.index_put_({ "...", i }, torch::tensor(X[i]));
}