Fix problem with the tensor-based code path

This commit is contained in:
Ricardo Montañana Gómez 2023-07-30 19:00:02 +02:00
parent 53697648e7
commit 43bb017d5d
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
10 changed files with 95 additions and 40 deletions

4
.vscode/launch.json vendored
View File

@ -12,10 +12,10 @@
"-m", "-m",
"TAN", "TAN",
"-p", "-p",
"../../data/", "/Users/rmontanana/Code/discretizbench/datasets/",
"--tensors" "--tensors"
], ],
"cwd": "${workspaceFolder}/build/sample/", //"cwd": "${workspaceFolder}/build/sample/",
}, },
{ {
"type": "lldb", "type": "lldb",

12
TAN_iris.dot Normal file
View File

@ -0,0 +1,12 @@
// Graphviz description of a Bayesian network drawn with the circo layout.
// The filled "class" node has an edge to each of the four feature nodes
// (sepallength, sepalwidth, petallength, petalwidth); the remaining edges
// link the feature nodes to one another.
// NOTE(review): presumably generated by the classifier's graph() output for
// the TAN model on the iris dataset — confirm before editing by hand.
digraph BayesNet {
label=<BayesNet >
fontsize=30
fontcolor=blue
labelloc=t
layout=circo
class [shape=circle, fontcolor=red, fillcolor=lightblue, style=filled ]
class -> sepallength class -> sepalwidth class -> petallength class -> petalwidth petallength [shape=circle]
petallength -> sepallength petalwidth [shape=circle]
sepallength [shape=circle]
sepallength -> sepalwidth sepalwidth [shape=circle]
sepalwidth -> petalwidth }

View File

@ -1,7 +1,6 @@
#include <iostream> #include <iostream>
#include <torch/torch.h> #include <torch/torch.h>
#include <string> #include <string>
#include <thread>
#include <map> #include <map>
#include <argparse/argparse.hpp> #include <argparse/argparse.hpp>
#include "ArffFiles.h" #include "ArffFiles.h"
@ -42,7 +41,7 @@ bool file_exists(const std::string& name)
} }
pair<vector<vector<int>>, vector<int>> extract_indices(vector<int> indices, vector<vector<int>> X, vector<int> y) pair<vector<vector<int>>, vector<int>> extract_indices(vector<int> indices, vector<vector<int>> X, vector<int> y)
{ {
vector<vector<int>> Xr; vector<vector<int>> Xr; // nxm
vector<int> yr; vector<int> yr;
for (int col = 0; col < X.size(); ++col) { for (int col = 0; col < X.size(); ++col) {
Xr.push_back(vector<int>()); Xr.push_back(vector<int>());
@ -199,6 +198,7 @@ int main(int argc, char** argv)
torch::Tensor Xtestt = torch::index_select(Xt, 1, ttest); torch::Tensor Xtestt = torch::index_select(Xt, 1, ttest);
torch::Tensor ytestt = yt.index({ ttest }); torch::Tensor ytestt = yt.index({ ttest });
clf->fit(Xtraint, ytraint, features, className, states); clf->fit(Xtraint, ytraint, features, className, states);
auto temp = clf->predict(Xtraint);
score_train = clf->score(Xtraint, ytraint); score_train = clf->score(Xtraint, ytraint);
score_test = clf->score(Xtestt, ytestt); score_test = clf->score(Xtestt, ytestt);
} else { } else {

View File

@ -8,6 +8,7 @@ namespace bayesnet {
public: public:
virtual BaseClassifier& fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states) = 0; virtual BaseClassifier& fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states) = 0;
virtual BaseClassifier& fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states) = 0; virtual BaseClassifier& fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states) = 0;
torch::Tensor virtual predict(torch::Tensor& X) = 0;
vector<int> virtual predict(vector<vector<int>>& X) = 0; vector<int> virtual predict(vector<vector<int>>& X) = 0;
float virtual score(vector<vector<int>>& X, vector<int>& y) = 0; float virtual score(vector<vector<int>>& X, vector<int>& y) = 0;
float virtual score(torch::Tensor& X, torch::Tensor& y) = 0; float virtual score(torch::Tensor& X, torch::Tensor& y) = 0;

View File

@ -65,23 +65,14 @@ namespace bayesnet {
} }
} }
} }
Tensor Classifier::predict(Tensor& X) Tensor Classifier::predict(Tensor& X)
{ {
if (!fitted) { if (!fitted) {
throw logic_error("Classifier has not been fitted"); throw logic_error("Classifier has not been fitted");
} }
auto m_ = X.size(0); auto Xt = torch::transpose(X, 0, 1); // Base classifiers expect samples as columns
auto n_ = X.size(1); auto y_proba = model.predict(Xt);
//auto Xt = torch::transpose(X, 0, 1); return y_proba.argmax(1);
vector<vector<int>> Xd(n_, vector<int>(m_, 0));
for (auto i = 0; i < n_; i++) {
auto temp = X.index({ "...", i });
Xd[i] = vector<int>(temp.data_ptr<int>(), temp.data_ptr<int>() + temp.numel());
}
auto yp = model.predict(Xd);
auto ypred = torch::tensor(yp, torch::kInt32);
return ypred;
} }
vector<int> Classifier::predict(vector<vector<int>>& X) vector<int> Classifier::predict(vector<vector<int>>& X)
{ {
@ -102,8 +93,7 @@ namespace bayesnet {
if (!fitted) { if (!fitted) {
throw logic_error("Classifier has not been fitted"); throw logic_error("Classifier has not been fitted");
} }
auto Xt = torch::transpose(X, 0, 1); Tensor y_pred = predict(X);
Tensor y_pred = predict(Xt);
return (y_pred == y).sum().item<float>() / y.size(0); return (y_pred == y).sum().item<float>() / y.size(0);
} }
float Classifier::score(vector<vector<int>>& X, vector<int>& y) float Classifier::score(vector<vector<int>>& X, vector<int>& y)
@ -111,13 +101,13 @@ namespace bayesnet {
if (!fitted) { if (!fitted) {
throw logic_error("Classifier has not been fitted"); throw logic_error("Classifier has not been fitted");
} }
auto m_ = X[0].size(); // auto m_ = X[0].size();
auto n_ = X.size(); // auto n_ = X.size();
vector<vector<int>> Xd(n_, vector<int>(m_, 0)); // vector<vector<int>> Xd(n_, vector<int>(m_, 0));
for (auto i = 0; i < n_; i++) { // for (auto i = 0; i < n_; i++) {
Xd[i] = vector<int>(X[i].begin(), X[i].end()); // Xd[i] = vector<int>(X[i].begin(), X[i].end());
} // }
return model.score(Xd, y); return model.score(X, y);
} }
vector<string> Classifier::show() vector<string> Classifier::show()
{ {

View File

@ -35,7 +35,7 @@ namespace bayesnet {
int getNumberOfNodes() override; int getNumberOfNodes() override;
int getNumberOfEdges() override; int getNumberOfEdges() override;
int getNumberOfStates() override; int getNumberOfStates() override;
Tensor predict(Tensor& X); Tensor predict(Tensor& X) override;
vector<int> predict(vector<vector<int>>& X) override; vector<int> predict(vector<vector<int>>& X) override;
float score(Tensor& X, Tensor& y) override; float score(Tensor& X, Tensor& y) override;
float score(vector<vector<int>>& X, vector<int>& y) override; float score(vector<vector<int>>& X, vector<int>& y) override;

View File

@ -16,15 +16,22 @@ namespace bayesnet {
train(); train();
// Train models // Train models
n_models = models.size(); n_models = models.size();
auto Xt = torch::transpose(X, 0, 1);
for (auto i = 0; i < n_models; ++i) { for (auto i = 0; i < n_models; ++i) {
models[i]->fit(Xv, yv, features, className, states); if (Xv == vector<vector<int>>()) {
// fit with tensors
models[i]->fit(Xt, y, features, className, states);
} else {
// fit with vectors
models[i]->fit(Xv, yv, features, className, states);
}
} }
fitted = true; fitted = true;
return *this; return *this;
} }
Ensemble& Ensemble::fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states) Ensemble& Ensemble::fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states)
{ {
this->X = X; this->X = torch::transpose(X, 0, 1);
this->y = y; this->y = y;
Xv = vector<vector<int>>(); Xv = vector<vector<int>>();
yv = vector<int>(y.data_ptr<int>(), y.data_ptr<int>() + y.size(0)); yv = vector<int>(y.data_ptr<int>(), y.data_ptr<int>() + y.size(0));
@ -41,17 +48,6 @@ namespace bayesnet {
yv = y; yv = y;
return build(features, className, states); return build(features, className, states);
} }
// Predicts with every model of the ensemble and merges the results via voting().
// Throws logic_error if the ensemble has not been fitted.
// NOTE(review): the vote buffer is sized with X.size(0), i.e. dimension 0 of X
// is taken as the number of samples here — confirm against the layout callers pass.
Tensor Ensemble::predict(Tensor& X)
{
if (!fitted) {
throw logic_error("Ensemble has not been fitted");
}
// One row per entry along dimension 0 of X, one column per model
Tensor y_pred = torch::zeros({ X.size(0), n_models }, kInt32);
for (auto i = 0; i < n_models; ++i) {
// Column i receives the predictions of model i
y_pred.index_put_({ "...", i }, models[i]->predict(X));
}
return torch::tensor(voting(y_pred));
}
vector<int> Ensemble::voting(Tensor& y_pred) vector<int> Ensemble::voting(Tensor& y_pred)
{ {
auto y_pred_ = y_pred.accessor<int, 2>(); auto y_pred_ = y_pred.accessor<int, 2>();
@ -66,6 +62,18 @@ namespace bayesnet {
} }
return y_pred_final; return y_pred_final;
} }
// Runs every fitted model on X, stores each model's predictions in its own
// column of a vote matrix and returns the outcome of voting() as a tensor.
// Throws logic_error if the ensemble has not been fitted.
// NOTE(review): the vote matrix has X.size(1) rows, so X is presumably laid
// out as features x samples — confirm against callers.
Tensor Ensemble::predict(Tensor& X)
{
    if (!fitted) {
        throw logic_error("Ensemble has not been fitted");
    }
    // One row per entry along dimension 1 of X, one column per model
    Tensor votes = torch::zeros({ X.size(1), n_models }, kInt32);
    int model_idx = 0;
    while (model_idx < n_models) {
        votes.index_put_({ "...", model_idx }, models[model_idx]->predict(X));
        ++model_idx;
    }
    return torch::tensor(voting(votes));
}
vector<int> Ensemble::predict(vector<vector<int>>& X) vector<int> Ensemble::predict(vector<vector<int>>& X)
{ {
if (!fitted) { if (!fitted) {

View File

@ -32,7 +32,7 @@ namespace bayesnet {
virtual ~Ensemble() = default; virtual ~Ensemble() = default;
Ensemble& fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states) override; Ensemble& fit(vector<vector<int>>& X, vector<int>& y, vector<string>& features, string className, map<string, vector<int>>& states) override;
Ensemble& fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states) override; Ensemble& fit(torch::Tensor& X, torch::Tensor& y, vector<string>& features, string className, map<string, vector<int>>& states) override;
Tensor predict(Tensor& X); Tensor predict(Tensor& X) override;
vector<int> predict(vector<vector<int>>& X) override; vector<int> predict(vector<vector<int>>& X) override;
float score(Tensor& X, Tensor& y) override; float score(Tensor& X, Tensor& y) override;
float score(vector<vector<int>>& X, vector<int>& y) override; float score(vector<vector<int>>& X, vector<int>& y) override;

View File

@ -170,6 +170,34 @@ namespace bayesnet {
} }
fitted = true; fitted = true;
} }
// Computes the class probabilities of a batch of samples.
// samples is indexed along dimension 0: samples[i] is handed to
// predict_sample() as one sample (one value per feature).
// Returns a kFloat64 tensor of shape (samples.size(0), classNumStates) whose
// row i holds the probabilities predict_sample() produced for sample i.
// Throws logic_error if the network has not been fitted.
Tensor Network::predict_proba(const Tensor& samples)
{
    if (!fitted) {
        throw logic_error("You must call fit() before calling predict_proba()");
    }
    const auto n_samples = samples.size(0);
    Tensor result = torch::zeros({ n_samples, classNumStates }, torch::kFloat64);
    for (int i = 0; i < n_samples; ++i) {
        // Row i already is the i-th sample; indexing it directly avoids
        // building a transposed view only to read its i-th column.
        const auto classProbabilities = predict_sample(samples.index({ i, "..." }));
        result.index_put_({ i, "..." }, torch::tensor(classProbabilities, torch::kFloat64));
    }
    return result;
}
// Tensor overload of predict().
// NOTE: unlike the vector overload, this one returns the per-class
// probabilities (shape (samples.size(0), classNumStates), kFloat64), not the
// predicted labels; Classifier::predict reduces the result with argmax(1).
// The computation is identical to predict_proba(), so it is delegated there
// instead of being duplicated.
// Throws logic_error if the network has not been fitted.
Tensor Network::predict(const Tensor& samples)
{
    if (!fitted) {
        // Checked here as well so the message keeps naming predict()
        throw logic_error("You must call fit() before calling predict()");
    }
    return predict_proba(samples);
}
vector<int> Network::predict(const vector<vector<int>>& tsamples) vector<int> Network::predict(const vector<vector<int>>& tsamples)
{ {
@ -231,6 +259,19 @@ namespace bayesnet {
} }
return exactInference(evidence); return exactInference(evidence);
} }
// Computes the class probabilities of a single sample given as a tensor.
// sample must have exactly features.size() entries along dimension 0; entry i
// is read as an int (so sample[i] must be a single-element tensor) and used as
// the evidence for features[i] before exactInference() is run.
// Throws invalid_argument when the number of entries differs from the number
// of features.
vector<double> Network::predict_sample(const Tensor& sample)
{
    const auto n_values = sample.size(0);
    // size(0) is signed while features.size() is unsigned: compare both in the
    // unsigned domain (a tensor dimension is never negative).
    if (static_cast<size_t>(n_values) != features.size()) {
        throw invalid_argument("Sample size (" + to_string(n_values) +
            ") does not match the number of features (" + to_string(features.size()) + ")");
    }
    map<string, int> evidence;
    for (int i = 0; i < n_values; ++i) {
        evidence[features[i]] = sample[i].item<int>();
    }
    return exactInference(evidence);
}
double Network::computeFactor(map<string, int>& completeEvidence) double Network::computeFactor(map<string, int>& completeEvidence)
{ {
double result = 1.0; double result = 1.0;

View File

@ -18,6 +18,7 @@ namespace bayesnet {
torch::Tensor samples; torch::Tensor samples;
bool isCyclic(const std::string&, std::unordered_set<std::string>&, std::unordered_set<std::string>&); bool isCyclic(const std::string&, std::unordered_set<std::string>&, std::unordered_set<std::string>&);
vector<double> predict_sample(const vector<int>&); vector<double> predict_sample(const vector<int>&);
vector<double> predict_sample(const torch::Tensor&);
vector<double> exactInference(map<string, int>&); vector<double> exactInference(map<string, int>&);
double computeFactor(map<string, int>&); double computeFactor(map<string, int>&);
double mutual_info(torch::Tensor&, torch::Tensor&); double mutual_info(torch::Tensor&, torch::Tensor&);
@ -43,9 +44,11 @@ namespace bayesnet {
void fit(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&); void fit(const vector<vector<int>>&, const vector<int>&, const vector<string>&, const string&);
void fit(torch::Tensor&, torch::Tensor&, const vector<string>&, const string&); void fit(torch::Tensor&, torch::Tensor&, const vector<string>&, const string&);
vector<int> predict(const vector<vector<int>>&); vector<int> predict(const vector<vector<int>>&);
torch::Tensor predict(const torch::Tensor&);
//Computes the conditional edge weight of variable index u and v conditioned on class_node //Computes the conditional edge weight of variable index u and v conditioned on class_node
torch::Tensor conditionalEdgeWeight(); torch::Tensor conditionalEdgeWeight();
vector<vector<double>> predict_proba(const vector<vector<int>>&); vector<vector<double>> predict_proba(const vector<vector<int>>&);
torch::Tensor predict_proba(const torch::Tensor&);
double score(const vector<vector<int>>&, const vector<int>&); double score(const vector<vector<int>>&, const vector<int>&);
vector<string> show(); vector<string> show();
vector<string> graph(const string& title); // Returns a vector of strings representing the graph in graphviz format vector<string> graph(const string& title); // Returns a vector of strings representing the graph in graphviz format