From 43bb017d5d7d1bb24bf624e8bc6dd2af80298971 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Sun, 30 Jul 2023 19:00:02 +0200 Subject: [PATCH] Fix problem with tensors way --- .vscode/launch.json | 4 ++-- TAN_iris.dot | 12 ++++++++++ sample/sample.cc | 4 ++-- src/BayesNet/BaseClassifier.h | 1 + src/BayesNet/Classifier.cc | 32 ++++++++++----------------- src/BayesNet/Classifier.h | 2 +- src/BayesNet/Ensemble.cc | 34 ++++++++++++++++++----------- src/BayesNet/Ensemble.h | 2 +- src/BayesNet/Network.cc | 41 +++++++++++++++++++++++++++++++++++ src/BayesNet/Network.h | 3 +++ 10 files changed, 95 insertions(+), 40 deletions(-) create mode 100644 TAN_iris.dot diff --git a/.vscode/launch.json b/.vscode/launch.json index 14f9cc8..4deb176 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -12,10 +12,10 @@ "-m", "TAN", "-p", - "../../data/", + "/Users/rmontanana/Code/discretizbench/datasets/", "--tensors" ], - "cwd": "${workspaceFolder}/build/sample/", + //"cwd": "${workspaceFolder}/build/sample/", }, { "type": "lldb", diff --git a/TAN_iris.dot b/TAN_iris.dot new file mode 100644 index 0000000..d3ce2cc --- /dev/null +++ b/TAN_iris.dot @@ -0,0 +1,12 @@ +digraph BayesNet { +label= +fontsize=30 +fontcolor=blue +labelloc=t +layout=circo + class [shape=circle, fontcolor=red, fillcolor=lightblue, style=filled ] + class -> sepallength class -> sepalwidth class -> petallength class -> petalwidth petallength [shape=circle] + petallength -> sepallength petalwidth [shape=circle] + sepallength [shape=circle] + sepallength -> sepalwidth sepalwidth [shape=circle] + sepalwidth -> petalwidth } diff --git a/sample/sample.cc b/sample/sample.cc index 2d7efa4..bd74edc 100644 --- a/sample/sample.cc +++ b/sample/sample.cc @@ -1,7 +1,6 @@ #include #include #include -#include #include #include #include "ArffFiles.h" @@ -42,7 +41,7 @@ bool file_exists(const std::string& name) } pair>, vector> extract_indices(vector indices, vector> X, vector y) { - vector> Xr; + vector> Xr; // nxm vector yr; for (int col = 0; col < X.size(); ++col) { Xr.push_back(vector()); @@ -199,6 +198,7 @@ int main(int argc, char** argv) torch::Tensor Xtestt = torch::index_select(Xt, 1, ttest); torch::Tensor ytestt = yt.index({ ttest }); clf->fit(Xtraint, ytraint, features, className, states); + auto temp = clf->predict(Xtraint); score_train = clf->score(Xtraint, ytraint); score_test = clf->score(Xtestt, ytestt); } else { diff --git a/src/BayesNet/BaseClassifier.h b/src/BayesNet/BaseClassifier.h index 00a60af..16daaa6 100644 --- a/src/BayesNet/BaseClassifier.h +++ b/src/BayesNet/BaseClassifier.h @@ -8,6 +8,7 @@ namespace bayesnet { public: virtual BaseClassifier& fit(vector>& X, vector& y, vector& features, string className, map>& states) = 0; virtual BaseClassifier& fit(torch::Tensor& X, torch::Tensor& y, vector& features, string className, map>& states) = 0; + torch::Tensor virtual predict(torch::Tensor& X) = 0; vector virtual predict(vector>& X) = 0; float virtual score(vector>& X, vector& y) = 0; float virtual score(torch::Tensor& X, torch::Tensor& y) = 0; diff --git a/src/BayesNet/Classifier.cc b/src/BayesNet/Classifier.cc index 545525e..77624fc 100644 --- a/src/BayesNet/Classifier.cc +++ b/src/BayesNet/Classifier.cc @@ -65,23 +65,14 @@ namespace bayesnet { } } } - Tensor Classifier::predict(Tensor& X) { if (!fitted) { throw logic_error("Classifier has not been fitted"); } - auto m_ = X.size(0); - auto n_ = X.size(1); - //auto Xt = torch::transpose(X, 0, 1); - vector> Xd(n_, vector(m_, 0)); - for (auto i = 0; i < n_; i++) { - auto temp = X.index({ "...", i }); - Xd[i] = vector(temp.data_ptr(), temp.data_ptr() + temp.numel()); - } - auto yp = model.predict(Xd); - auto ypred = torch::tensor(yp, torch::kInt32); - return ypred; + auto Xt = torch::transpose(X, 0, 1); // Base classifiers expect samples as columns + auto y_proba = model.predict(Xt); + return y_proba.argmax(1); } vector Classifier::predict(vector>& X) { @@ -102,8 +93,7 @@ namespace bayesnet { if (!fitted) { throw logic_error("Classifier has not been fitted"); } - auto Xt = torch::transpose(X, 0, 1); - Tensor y_pred = predict(Xt); + Tensor y_pred = predict(X); return (y_pred == y).sum().item() / y.size(0); } float Classifier::score(vector>& X, vector& y) @@ -111,13 +101,13 @@ namespace bayesnet { if (!fitted) { throw logic_error("Classifier has not been fitted"); } - auto m_ = X[0].size(); - auto n_ = X.size(); - vector> Xd(n_, vector(m_, 0)); - for (auto i = 0; i < n_; i++) { - Xd[i] = vector(X[i].begin(), X[i].end()); - } - return model.score(Xd, y); + // auto m_ = X[0].size(); + // auto n_ = X.size(); + // vector> Xd(n_, vector(m_, 0)); + // for (auto i = 0; i < n_; i++) { + // Xd[i] = vector(X[i].begin(), X[i].end()); + // } + return model.score(X, y); } vector Classifier::show() { diff --git a/src/BayesNet/Classifier.h b/src/BayesNet/Classifier.h index ad56336..f5ee534 100644 --- a/src/BayesNet/Classifier.h +++ b/src/BayesNet/Classifier.h @@ -35,7 +35,7 @@ namespace bayesnet { int getNumberOfNodes() override; int getNumberOfEdges() override; int getNumberOfStates() override; - Tensor predict(Tensor& X); + Tensor predict(Tensor& X) override; vector predict(vector>& X) override; float score(Tensor& X, Tensor& y) override; float score(vector>& X, vector& y) override; diff --git a/src/BayesNet/Ensemble.cc b/src/BayesNet/Ensemble.cc index dce0d3d..3e4d1a6 100644 --- a/src/BayesNet/Ensemble.cc +++ b/src/BayesNet/Ensemble.cc @@ -16,15 +16,22 @@ namespace bayesnet { train(); // Train models n_models = models.size(); + auto Xt = torch::transpose(X, 0, 1); for (auto i = 0; i < n_models; ++i) { - models[i]->fit(Xv, yv, features, className, states); + if (Xv == vector>()) { + // fit with tensors + models[i]->fit(Xt, y, features, className, states); + } else { + // fit with vectors + models[i]->fit(Xv, yv, features, className, states); + } } fitted = true; return *this; } Ensemble& Ensemble::fit(torch::Tensor& X, torch::Tensor& y, vector& features, string className, map>& states) { - this->X = X; + this->X = torch::transpose(X, 0, 1); this->y = y; Xv = vector>(); yv = vector(y.data_ptr(), y.data_ptr() + y.size(0)); @@ -41,17 +48,6 @@ namespace bayesnet { yv = y; return build(features, className, states); } - Tensor Ensemble::predict(Tensor& X) - { - if (!fitted) { - throw logic_error("Ensemble has not been fitted"); - } - Tensor y_pred = torch::zeros({ X.size(0), n_models }, kInt32); - for (auto i = 0; i < n_models; ++i) { - y_pred.index_put_({ "...", i }, models[i]->predict(X)); - } - return torch::tensor(voting(y_pred)); - } vector Ensemble::voting(Tensor& y_pred) { auto y_pred_ = y_pred.accessor(); @@ -66,6 +62,18 @@ namespace bayesnet { } return y_pred_final; } + Tensor Ensemble::predict(Tensor& X) + { + if (!fitted) { + throw logic_error("Ensemble has not been fitted"); + } + Tensor y_pred = torch::zeros({ X.size(1), n_models }, kInt32); + for (auto i = 0; i < n_models; ++i) { + auto ypredict = models[i]->predict(X); + y_pred.index_put_({ "...", i }, ypredict); + } + return torch::tensor(voting(y_pred)); + } vector Ensemble::predict(vector>& X) { if (!fitted) { diff --git a/src/BayesNet/Ensemble.h b/src/BayesNet/Ensemble.h index d45575d..2f85092 100644 --- a/src/BayesNet/Ensemble.h +++ b/src/BayesNet/Ensemble.h @@ -32,7 +32,7 @@ namespace bayesnet { virtual ~Ensemble() = default; Ensemble& fit(vector>& X, vector& y, vector& features, string className, map>& states) override; Ensemble& fit(torch::Tensor& X, torch::Tensor& y, vector& features, string className, map>& states) override; - Tensor predict(Tensor& X); + Tensor predict(Tensor& X) override; vector predict(vector>& X) override; float score(Tensor& X, Tensor& y) override; float score(vector>& X, vector& y) override; diff --git a/src/BayesNet/Network.cc b/src/BayesNet/Network.cc index 35b3cc5..f46e1e9 100644 --- a/src/BayesNet/Network.cc +++ b/src/BayesNet/Network.cc @@ -170,6 +170,34 @@ namespace bayesnet { } fitted = true; } + Tensor Network::predict_proba(const Tensor& samples) + { + if (!fitted) { + throw logic_error("You must call fit() before calling predict_proba()"); + } + Tensor result = torch::zeros({ samples.size(0), classNumStates }, torch::kFloat64); + auto Xt = torch::transpose(samples, 0, 1); + for (int i = 0; i < samples.size(0); ++i) { + auto sample = Xt.index({ "...", i }); + auto classProbabilities = predict_sample(sample); + result.index_put_({ i, "..." }, torch::tensor(classProbabilities, torch::kFloat64)); + } + return result; + } + Tensor Network::predict(const Tensor& samples) + { + if (!fitted) { + throw logic_error("You must call fit() before calling predict()"); + } + Tensor result = torch::zeros({ samples.size(0), classNumStates }, torch::kFloat64); + auto Xt = torch::transpose(samples, 0, 1); + for (int i = 0; i < samples.size(0); ++i) { + auto sample = Xt.index({ "...", i }); + auto classProbabilities = predict_sample(sample); + result.index_put_({ i, "..." }, torch::tensor(classProbabilities, torch::kFloat64)); + } + return result; + } vector Network::predict(const vector>& tsamples) { @@ -231,6 +259,19 @@ namespace bayesnet { } return exactInference(evidence); } + vector Network::predict_sample(const Tensor& sample) + { + // Ensure the sample size is equal to the number of features + if (sample.size(0) != features.size()) { + throw invalid_argument("Sample size (" + to_string(sample.size(0)) + + ") does not match the number of features (" + to_string(features.size()) + ")"); + } + map evidence; + for (int i = 0; i < sample.size(0); ++i) { + evidence[features[i]] = sample[i].item(); + } + return exactInference(evidence); + } double Network::computeFactor(map& completeEvidence) { double result = 1.0; diff --git a/src/BayesNet/Network.h b/src/BayesNet/Network.h index abd02b4..f763dde 100644 --- a/src/BayesNet/Network.h +++ b/src/BayesNet/Network.h @@ -18,6 +18,7 @@ namespace bayesnet { torch::Tensor samples; bool isCyclic(const std::string&, std::unordered_set&, std::unordered_set&); vector predict_sample(const vector&); + vector predict_sample(const torch::Tensor&); vector exactInference(map&); double computeFactor(map&); double mutual_info(torch::Tensor&, torch::Tensor&); @@ -43,9 +44,11 @@ namespace bayesnet { void fit(const vector>&, const vector&, const vector&, const string&); void fit(torch::Tensor&, torch::Tensor&, const vector&, const string&); vector predict(const vector>&); + torch::Tensor predict(const torch::Tensor&); //Computes the conditional edge weight of variable index u and v conditioned on class_node torch::Tensor conditionalEdgeWeight(); vector> predict_proba(const vector>&); + torch::Tensor predict_proba(const torch::Tensor&); double score(const vector>&, const vector&); vector show(); vector graph(const string& title); // Returns a vector of strings representing the graph in graphviz format