Fix score with tensors and finis sample

This commit is contained in:
2023-07-26 13:29:47 +02:00
parent 4a54bd42a2
commit af7a1d2b40
3 changed files with 7 additions and 47 deletions

View File

@@ -202,24 +202,18 @@ int main(int argc, char** argv)
auto [train, test] = fold->getFold(i);
cout << "Fold: " << i + 1 << endl;
if (tensors) {
cout << "Xt shape: " << Xt.sizes() << endl;
cout << "yt shape: " << yt.sizes() << endl;
auto ttrain = torch::tensor(train, torch::kInt64);
auto ttest = torch::tensor(test, torch::kInt64);
torch::Tensor Xtraint = torch::index_select(Xt, 1, ttrain);
torch::Tensor ytraint = yt.index({ ttrain });
torch::Tensor Xtestt = torch::index_select(Xt, 1, ttest);
torch::Tensor ytestt = yt.index({ ttest });
cout << "Train: " << Xtraint.size(0) << " x " << Xtraint.size(1) << " " << ytraint.size(0) << endl;
cout << "Test : " << Xtestt.size(0) << " x " << Xtestt.size(1) << " " << ytestt.size(0) << endl;
clf->fit(Xtraint, ytraint, features, className, states);
score_train = clf->score(Xtraint, ytraint);
score_test = clf->score(Xtestt, ytestt);
} else {
auto [Xtrain, ytrain] = extract_indices(train, Xd, y);
auto [Xtest, ytest] = extract_indices(test, Xd, y);
cout << "Train: " << Xtrain.size() << " x " << Xtrain[0].size() << " " << ytrain.size() << endl;
cout << "Test : " << Xtest.size() << " x " << Xtest[0].size() << " " << ytest.size() << endl;
clf->fit(Xtrain, ytrain, features, className, states);
score_train = clf->score(Xtrain, ytrain);
score_test = clf->score(Xtest, ytest);