Fix input data dimensions

This commit is contained in:
2025-07-04 11:10:25 +02:00
parent 37a765e6b0
commit 57c6693842
3 changed files with 41 additions and 32 deletions

View File

@@ -22,9 +22,10 @@ public:
tie(Xt, yt, featurest, classNamet, statest) = loadDataset(file_name, true, discretize);
// Xv is always discretized
tie(Xv, yv, featuresv, classNamev, statesv) = loadFile(file_name);
auto yresized = yt.view({ yt.size(0), 1 });
dataset = torch::cat({ Xt, yresized }, 1);
nSamples = dataset.size(0);
// Xt is [features, samples], yt is [samples], need to reshape y to [1, samples] for concatenation
auto yresized = yt.view({ 1, yt.size(0) });
dataset = torch::cat({ Xt, yresized }, 0);
nSamples = dataset.size(1); // samples is the second dimension now
weights = torch::full({ nSamples }, 1.0 / nSamples, torch::kDouble);
weightsv = std::vector<double>(nSamples, 1.0 / nSamples);
classNumStates = discretize ? statest.at(classNamet).size() : 0;