Fix some mistakes in tensors treatment
This commit is contained in:
@@ -55,16 +55,16 @@ tuple<Tensor, Tensor, vector<string>, string, map<string, vector<int>>> loadData
|
||||
auto states = map<string, vector<int>>();
|
||||
if (discretize_dataset) {
|
||||
auto Xr = discretizeDataset(X, y);
|
||||
Xd = torch::zeros({ static_cast<int64_t>(Xr[0].size()), static_cast<int64_t>(Xr.size()) }, torch::kInt64);
|
||||
Xd = torch::zeros({ static_cast<int>(Xr[0].size()), static_cast<int>(Xr.size()) }, torch::kInt32);
|
||||
for (int i = 0; i < features.size(); ++i) {
|
||||
states[features[i]] = vector<int>(*max_element(Xr[i].begin(), Xr[i].end()) + 1);
|
||||
iota(begin(states[features[i]]), end(states[features[i]]), 0);
|
||||
Xd.index_put_({ "...", i }, torch::tensor(Xr[i], torch::kInt64));
|
||||
Xd.index_put_({ "...", i }, torch::tensor(Xr[i], torch::kInt32));
|
||||
}
|
||||
states[className] = vector<int>(*max_element(y.begin(), y.end()) + 1);
|
||||
iota(begin(states[className]), end(states[className]), 0);
|
||||
} else {
|
||||
Xd = torch::zeros({ static_cast<int64_t>(X[0].size()), static_cast<int64_t>(X.size()) }, torch::kFloat32);
|
||||
Xd = torch::zeros({ static_cast<int>(X[0].size()), static_cast<int>(X.size()) }, torch::kFloat32);
|
||||
for (int i = 0; i < features.size(); ++i) {
|
||||
Xd.index_put_({ "...", i }, torch::tensor(X[i]));
|
||||
}
|
||||
|
Reference in New Issue
Block a user