Fix score with tensors and finish sample

This commit is contained in:
Ricardo Montañana Gómez 2023-07-26 13:29:47 +02:00
parent 4a54bd42a2
commit af7a1d2b40
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
3 changed files with 7 additions and 47 deletions

View File

@ -202,24 +202,18 @@ int main(int argc, char** argv)
auto [train, test] = fold->getFold(i); auto [train, test] = fold->getFold(i);
cout << "Fold: " << i + 1 << endl; cout << "Fold: " << i + 1 << endl;
if (tensors) { if (tensors) {
cout << "Xt shape: " << Xt.sizes() << endl;
cout << "yt shape: " << yt.sizes() << endl;
auto ttrain = torch::tensor(train, torch::kInt64); auto ttrain = torch::tensor(train, torch::kInt64);
auto ttest = torch::tensor(test, torch::kInt64); auto ttest = torch::tensor(test, torch::kInt64);
torch::Tensor Xtraint = torch::index_select(Xt, 1, ttrain); torch::Tensor Xtraint = torch::index_select(Xt, 1, ttrain);
torch::Tensor ytraint = yt.index({ ttrain }); torch::Tensor ytraint = yt.index({ ttrain });
torch::Tensor Xtestt = torch::index_select(Xt, 1, ttest); torch::Tensor Xtestt = torch::index_select(Xt, 1, ttest);
torch::Tensor ytestt = yt.index({ ttest }); torch::Tensor ytestt = yt.index({ ttest });
cout << "Train: " << Xtraint.size(0) << " x " << Xtraint.size(1) << " " << ytraint.size(0) << endl;
cout << "Test : " << Xtestt.size(0) << " x " << Xtestt.size(1) << " " << ytestt.size(0) << endl;
clf->fit(Xtraint, ytraint, features, className, states); clf->fit(Xtraint, ytraint, features, className, states);
score_train = clf->score(Xtraint, ytraint); score_train = clf->score(Xtraint, ytraint);
score_test = clf->score(Xtestt, ytestt); score_test = clf->score(Xtestt, ytestt);
} else { } else {
auto [Xtrain, ytrain] = extract_indices(train, Xd, y); auto [Xtrain, ytrain] = extract_indices(train, Xd, y);
auto [Xtest, ytest] = extract_indices(test, Xd, y); auto [Xtest, ytest] = extract_indices(test, Xd, y);
cout << "Train: " << Xtrain.size() << " x " << Xtrain[0].size() << " " << ytrain.size() << endl;
cout << "Test : " << Xtest.size() << " x " << Xtest[0].size() << " " << ytest.size() << endl;
clf->fit(Xtrain, ytrain, features, className, states); clf->fit(Xtrain, ytrain, features, className, states);
score_train = clf->score(Xtrain, ytrain); score_train = clf->score(Xtrain, ytrain);
score_test = clf->score(Xtest, ytest); score_test = clf->score(Xtest, ytest);

View File

@ -7,11 +7,6 @@ namespace bayesnet {
Classifier::Classifier(Network model) : model(model), m(0), n(0), metrics(Metrics()), fitted(false) {} Classifier::Classifier(Network model) : model(model), m(0), n(0), metrics(Metrics()), fitted(false) {}
Classifier& Classifier::build(vector<string>& features, string className, map<string, vector<int>>& states) Classifier& Classifier::build(vector<string>& features, string className, map<string, vector<int>>& states)
{ {
cout << "Building classifier..." << endl;
cout << "X sizes = " << X.sizes() << endl;
cout << "y sizes = " << y.sizes() << endl;
cout << "Xv size = " << Xv.size() << endl;
cout << "yv size = " << yv.size() << endl;
dataset = torch::cat({ X, y.view({y.size(0), 1}) }, 1); dataset = torch::cat({ X, y.view({y.size(0), 1}) }, 1);
this->features = features; this->features = features;
this->className = className; this->className = className;
@ -21,8 +16,10 @@ namespace bayesnet {
metrics = Metrics(dataset, features, className, n_classes); metrics = Metrics(dataset, features, className, n_classes);
train(); train();
if (Xv == vector<vector<int>>()) { if (Xv == vector<vector<int>>()) {
// fit with tensors
model.fit(X, y, features, className); model.fit(X, y, features, className);
} else { } else {
// fit with vectors
model.fit(Xv, yv, features, className); model.fit(Xv, yv, features, className);
} }
fitted = true; fitted = true;
@ -33,10 +30,6 @@ namespace bayesnet {
this->X = torch::transpose(X, 0, 1); this->X = torch::transpose(X, 0, 1);
this->y = y; this->y = y;
Xv = vector<vector<int>>(); Xv = vector<vector<int>>();
for (int i = 0; i < X.size(1); ++i) {
auto temp = X.index({ "...", i });
Xv.push_back(vector<int>(temp.data_ptr<int>(), temp.data_ptr<int>() + temp.numel()));
}
yv = vector<int>(y.data_ptr<int>(), y.data_ptr<int>() + y.size(0)); yv = vector<int>(y.data_ptr<int>(), y.data_ptr<int>() + y.size(0));
return build(features, className, states); return build(features, className, states);
} }
@ -109,7 +102,8 @@ namespace bayesnet {
if (!fitted) { if (!fitted) {
throw logic_error("Classifier has not been fitted"); throw logic_error("Classifier has not been fitted");
} }
Tensor y_pred = predict(X); auto Xt = torch::transpose(X, 0, 1);
Tensor y_pred = predict(Xt);
return (y_pred == y).sum().item<float>() / y.size(0); return (y_pred == y).sum().item<float>() / y.size(0);
} }
float Classifier::score(vector<vector<int>>& X, vector<int>& y) float Classifier::score(vector<vector<int>>& X, vector<int>& y)

View File

@ -99,6 +99,7 @@ namespace bayesnet {
features = featureNames; features = featureNames;
this->className = className; this->className = className;
dataset.clear(); dataset.clear();
// Specific part
classNumStates = torch::max(y).item<int>() + 1; classNumStates = torch::max(y).item<int>() + 1;
samples = torch::cat({ X, y.view({ y.size(0), 1 }) }, 1); samples = torch::cat({ X, y.view({ y.size(0), 1 }) }, 1);
for (int i = 0; i < featureNames.size(); ++i) { for (int i = 0; i < featureNames.size(); ++i) {
@ -110,36 +111,6 @@ namespace bayesnet {
dataset[featureNames[i]] = k; dataset[featureNames[i]] = k;
} }
dataset[className] = vector<int>(y.data_ptr<int>(), y.data_ptr<int>() + y.size(0)); dataset[className] = vector<int>(y.data_ptr<int>(), y.data_ptr<int>() + y.size(0));
// //
// // Check if data is ok
// cout << "******************************************************************" << endl;
// cout << "Check samples, sizes: " << samples.sizes() << endl;
// for (auto i = 0; i < features.size(); ++i) {
// cout << featureNames[i] << ": " << nodes[featureNames[i]]->getNumStates() << ": torch:max " << torch::max(samples.index({ "...", i })).item<int>() + 1 << " dataset" << *max_element(dataset[featureNames[i]].begin(), dataset[featureNames[i]].end()) + 1 << endl;
// }
// cout << className << ": " << nodes[className]->getNumStates() << ": torch:max " << torch::max(samples.index({ "...", -1 })) + 1 << endl;
// cout << "******************************************************************" << endl;
// //
// //
/*
*/
for (int i = 0; i < features.size(); ++i) {
cout << "Checking " << features[i] << endl;
auto column = torch::flatten(X.index({ "...", i }));
auto k = vector<int>();
for (auto i = 0; i < X.size(0); ++i) {
k.push_back(column[i].item<int>());
}
if (k != dataset[features[i]]) {
throw invalid_argument("Dataset and samples do not match");
}
}
/*
*/
completeFit(); completeFit();
} }
void Network::fit(const vector<vector<int>>& input_data, const vector<int>& labels, const vector<string>& featureNames, const string& className) void Network::fit(const vector<vector<int>>& input_data, const vector<int>& labels, const vector<string>& featureNames, const string& className)
@ -147,6 +118,8 @@ namespace bayesnet {
features = featureNames; features = featureNames;
this->className = className; this->className = className;
dataset.clear(); dataset.clear();
// Specific part
classNumStates = *max_element(labels.begin(), labels.end()) + 1;
// Build dataset & tensor of samples // Build dataset & tensor of samples
samples = torch::zeros({ static_cast<int>(input_data[0].size()), static_cast<int>(input_data.size() + 1) }, torch::kInt32); samples = torch::zeros({ static_cast<int>(input_data[0].size()), static_cast<int>(input_data.size() + 1) }, torch::kInt32);
for (int i = 0; i < featureNames.size(); ++i) { for (int i = 0; i < featureNames.size(); ++i) {
@ -155,7 +128,6 @@ namespace bayesnet {
} }
dataset[className] = labels; dataset[className] = labels;
samples.index_put_({ "...", -1 }, torch::tensor(labels, torch::kInt32)); samples.index_put_({ "...", -1 }, torch::tensor(labels, torch::kInt32));
classNumStates = *max_element(labels.begin(), labels.end()) + 1;
completeFit(); completeFit();
} }
void Network::completeFit() void Network::completeFit()