Fix problem with tensors way

This commit is contained in:
2023-07-30 19:00:02 +02:00
parent 53697648e7
commit 43bb017d5d
10 changed files with 95 additions and 40 deletions

View File

@@ -1,7 +1,6 @@
#include <iostream>
#include <torch/torch.h>
#include <string>
#include <thread>
#include <map>
#include <argparse/argparse.hpp>
#include "ArffFiles.h"
@@ -42,7 +41,7 @@ bool file_exists(const std::string& name)
}
pair<vector<vector<int>>, vector<int>> extract_indices(vector<int> indices, vector<vector<int>> X, vector<int> y)
{
vector<vector<int>> Xr;
vector<vector<int>> Xr; // nxm
vector<int> yr;
for (int col = 0; col < X.size(); ++col) {
Xr.push_back(vector<int>());
@@ -199,6 +198,7 @@ int main(int argc, char** argv)
torch::Tensor Xtestt = torch::index_select(Xt, 1, ttest);
torch::Tensor ytestt = yt.index({ ttest });
clf->fit(Xtraint, ytraint, features, className, states);
auto temp = clf->predict(Xtraint);
score_train = clf->score(Xtraint, ytraint);
score_test = clf->score(Xtestt, ytestt);
} else {