From e1c4221c115897c666af69e99927101b9a76a143 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Tue, 20 Feb 2024 10:58:21 +0100 Subject: [PATCH] Add predict_voting and predict_prob to ensemble --- src/BayesNet/Ensemble.cc | 66 +++++++++++++++++++++++++++++++--------- src/BayesNet/Ensemble.h | 22 ++++++++------ 2 files changed, 64 insertions(+), 24 deletions(-) diff --git a/src/BayesNet/Ensemble.cc b/src/BayesNet/Ensemble.cc index 4702733..2eedc8d 100644 --- a/src/BayesNet/Ensemble.cc +++ b/src/BayesNet/Ensemble.cc @@ -2,8 +2,8 @@ namespace bayesnet { - Ensemble::Ensemble() : Classifier(Network()), n_models(0) {} - + Ensemble::Ensemble(bool predict_voting) : Classifier(Network()), n_models(0), predict_voting(predict_voting) {}; + const std::string ENSEMBLE_NOT_FITTED = "Ensemble has not been fitted"; void Ensemble::trainModel(const torch::Tensor& weights) { n_models = models.size(); @@ -12,6 +12,7 @@ namespace bayesnet { models[i]->fit(dataset, features, className, states); } } + std::vector Ensemble::voting(torch::Tensor& y_pred) { auto y_pred_ = y_pred.accessor(); @@ -31,11 +32,55 @@ namespace bayesnet { } return y_pred_final; } + std::vector Ensemble::predict(std::vector>& X) + { + if (!fitted) { + throw std::logic_error(ENSEMBLE_NOT_FITTED); + } + return predict_voting ? do_predict_voting(X) : do_predict_prob(X); + + } torch::Tensor Ensemble::predict(torch::Tensor& X) { if (!fitted) { - throw std::logic_error("Ensemble has not been fitted"); + throw std::logic_error(ENSEMBLE_NOT_FITTED); } + return predict_voting ? do_predict_voting(X) : do_predict_prob(X); + } + torch::Tensor Ensemble::do_predict_prob(torch::Tensor& X) + { + torch::Tensor y_pred = torch::zeros({ X.size(1), n_models }, torch::kFloat32); + // auto threads{ std::vector() }; + // std::mutex mtx; + // for (auto i = 0; i < n_models; ++i) { + // threads.push_back(std::thread([&, i]() { + // auto ypredict = models[i]->predict(X); + // std::lock_guard lock(mtx); + // y_pred.index_put_({ "...", i }, ypredict); + // })); + // } + // for (auto& thread : threads) { + // thread.join(); + // } + return y_pred; + } + std::vector Ensemble::do_predict_prob(std::vector>& X) + { + // long m_ = X[0].size(); + // long n_ = X.size(); + // vector> Xd(n_, vector(m_, 0)); + // for (auto i = 0; i < n_; i++) { + // Xd[i] = vector(X[i].begin(), X[i].end()); + // } + // torch::Tensor y_pred = torch::zeros({ m_, n_models }, torch::kInt32); + // for (auto i = 0; i < n_models; ++i) { + // y_pred.index_put_({ "...", i }, torch::tensor(models[i]->predict(Xd), torch::kInt32)); + // } + // return voting(y_pred); + return std::vector(); + } + torch::Tensor Ensemble::do_predict_voting(torch::Tensor& X) + { torch::Tensor y_pred = torch::zeros({ X.size(1), n_models }, torch::kInt32); auto threads{ std::vector() }; std::mutex mtx; @@ -51,11 +96,8 @@ namespace bayesnet { } return torch::tensor(voting(y_pred)); } - std::vector Ensemble::predict(std::vector>& X) + std::vector Ensemble::do_predict_voting(std::vector>& X) { - if (!fitted) { - throw std::logic_error("Ensemble has not been fitted"); - } long m_ = X[0].size(); long n_ = X.size(); std::vector> Xd(n_, std::vector(m_, 0)); @@ -70,10 +112,7 @@ namespace bayesnet { } float Ensemble::score(torch::Tensor& X, torch::Tensor& y) { - if (!fitted) { - throw std::logic_error("Ensemble has not been fitted"); - } - auto y_pred = predict(X); + auto y_pred = predict_voting ? do_predict_voting(X) : do_predict_prob(X); int correct = 0; for (int i = 0; i < y_pred.size(0); ++i) { if (y_pred[i].item() == y[i].item()) { @@ -84,10 +123,7 @@ namespace bayesnet { } float Ensemble::score(std::vector>& X, std::vector& y) { - if (!fitted) { - throw std::logic_error("Ensemble has not been fitted"); - } - auto y_pred = predict(X); + auto y_pred = predict_voting ? do_predict_voting(X) : do_predict_prob(X); int correct = 0; for (int i = 0; i < y_pred.size(); ++i) { if (y_pred[i] == y[i]) { diff --git a/src/BayesNet/Ensemble.h b/src/BayesNet/Ensemble.h index 07fda9b..b748235 100644 --- a/src/BayesNet/Ensemble.h +++ b/src/BayesNet/Ensemble.h @@ -7,19 +7,15 @@ namespace bayesnet { class Ensemble : public Classifier { - private: - Ensemble& build(std::vector& features, std::string className, std::map>& states); - protected: - unsigned n_models; - std::vector> models; - std::vector significanceModels; - void trainModel(const torch::Tensor& weights) override; - std::vector voting(torch::Tensor& y_pred); public: - Ensemble(); + Ensemble(bool predict_voting = true); virtual ~Ensemble() = default; torch::Tensor predict(torch::Tensor& X) override; std::vector predict(std::vector>& X) override; + torch::Tensor do_predict_voting(torch::Tensor& X); + std::vector do_predict_voting(std::vector>& X); + torch::Tensor do_predict_prob(torch::Tensor& X); + std::vector do_predict_prob(std::vector>& X); float score(torch::Tensor& X, torch::Tensor& y) override; float score(std::vector>& X, std::vector& y) override; int getNumberOfNodes() const override; @@ -34,6 +30,14 @@ namespace bayesnet { void dump_cpt() const override { } + protected: + unsigned n_models; + std::vector> models; + std::vector significanceModels; + void trainModel(const torch::Tensor& weights) override; + std::vector voting(torch::Tensor& y_pred); + private: + bool predict_voting; }; } #endif