Fix problem with tensors way
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
#include <iostream>
|
||||
#include <torch/torch.h>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <map>
|
||||
#include <argparse/argparse.hpp>
|
||||
#include "ArffFiles.h"
|
||||
@@ -42,7 +41,7 @@ bool file_exists(const std::string& name)
|
||||
}
|
||||
pair<vector<vector<int>>, vector<int>> extract_indices(vector<int> indices, vector<vector<int>> X, vector<int> y)
|
||||
{
|
||||
vector<vector<int>> Xr;
|
||||
vector<vector<int>> Xr; // nxm
|
||||
vector<int> yr;
|
||||
for (int col = 0; col < X.size(); ++col) {
|
||||
Xr.push_back(vector<int>());
|
||||
@@ -199,6 +198,7 @@ int main(int argc, char** argv)
|
||||
torch::Tensor Xtestt = torch::index_select(Xt, 1, ttest);
|
||||
torch::Tensor ytestt = yt.index({ ttest });
|
||||
clf->fit(Xtraint, ytraint, features, className, states);
|
||||
auto temp = clf->predict(Xtraint);
|
||||
score_train = clf->score(Xtraint, ytraint);
|
||||
score_test = clf->score(Xtestt, ytestt);
|
||||
} else {
|
||||
|
Reference in New Issue
Block a user