From 7a8e0391dc8501a540181499f8f977c0737c8ba0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Mon, 10 Mar 2025 12:18:10 +0100 Subject: [PATCH] continue fixing xspode --- bayesnet/classifiers/XSPODE.cc | 57 +++++++------------ bayesnet/classifiers/XSPODE.h | 1 - bayesnet/ensembles/XBAODE.cc | 2 +- .../utils/{TensorUtils.hpp => TensorUtils.h} | 6 +- 4 files changed, 26 insertions(+), 40 deletions(-) rename bayesnet/utils/{TensorUtils.hpp => TensorUtils.h} (95%) diff --git a/bayesnet/classifiers/XSPODE.cc b/bayesnet/classifiers/XSPODE.cc index a139747..4449849 100644 --- a/bayesnet/classifiers/XSPODE.cc +++ b/bayesnet/classifiers/XSPODE.cc @@ -10,6 +10,7 @@ #include #include #include "XSPODE.h" +#include "bayesnet/utils/TensorUtils.h" namespace bayesnet { @@ -299,30 +300,6 @@ namespace bayesnet { return probabilities; } - // -------------------------------------- - // predict - // -------------------------------------- - // - // Return the class argmax( P(c|x) ). - // -------------------------------------- - int XSpode::predict(const std::vector& instance) const - { - auto p = predict_proba(instance); - return static_cast(std::distance(p.begin(), - std::max_element(p.begin(), p.end()))); - } - std::vector XSpode::predict(std::vector>& test_data) - { - auto probabilities = predict_proba(test_data); - std::vector predictions(probabilities.size(), 0); - - for (size_t i = 0; i < probabilities.size(); i++) { - predictions[i] = std::distance(probabilities[i].begin(), std::max_element(probabilities[i].begin(), probabilities[i].end())); - } - - return predictions; - } - // -------------------------------------- // Utility: normalize // -------------------------------------- @@ -393,26 +370,36 @@ namespace bayesnet { // ------------------------------------------------------ // Predict overrides (classifier interface) // ------------------------------------------------------ - torch::Tensor predict(torch::Tensor& X) + int XSpode::predict(const std::vector& instance) const { - auto X_ = TensorUtils::to_matrix(X); - return predict(X_); + auto p = predict_proba(instance); + return static_cast(std::distance(p.begin(), + std::max_element(p.begin(), p.end()))); } - std::vector predict(std::vector>& X) + std::vector XSpode::predict(std::vector>& test_data) { - auto proba = predict_proba(X); - std::vector predictions(proba.size(), 0); - for (size_t i = 0; i < proba.size(); i++) { - predictions[i] = std::distance(proba[i].begin(), std::max_element(proba[i].begin(), proba[i].end())); + auto probabilities = predict_proba(test_data); + std::vector predictions(probabilities.size(), 0); + + for (size_t i = 0; i < probabilities.size(); i++) { + predictions[i] = std::distance(probabilities[i].begin(), std::max_element(probabilities[i].begin(), probabilities[i].end())); } + return predictions; } - torch::Tensor predict_proba(torch::Tensor& X) + torch::Tensor XSpode::predict(torch::Tensor& X) + { + auto X_ = TensorUtils::to_matrix(X); + auto result = predict(X_); + return TensorUtils::to_tensor(result); + } + torch::Tensor XSpode::predict_proba(torch::Tensor& X) { auto X_ = TensorUtils::to_matrix(X); - return predict_proba(X_); + auto result = predict_proba(X_); + return TensorUtils::to_tensor(result); } - torch::Tensor Classifier::predict(torch::Tensor& X) + torch::Tensor XSpode::predict(torch::Tensor& X) { auto X_ = TensorUtils::to_matrix(X); auto predict = predict(X_); diff --git a/bayesnet/classifiers/XSPODE.h b/bayesnet/classifiers/XSPODE.h index 03848de..7f57b22 100644 --- a/bayesnet/classifiers/XSPODE.h +++ b/bayesnet/classifiers/XSPODE.h @@ -20,7 +20,6 @@ namespace bayesnet { std::vector predict_proba(const std::vector& instance) const; std::vector> predict_proba(const std::vector>& test_data); int predict(const std::vector& instance) const; - std::vector predict(std::vector>& test_data); void normalize(std::vector& v) const; std::string to_string() const; int getNFeatures() const; diff --git a/bayesnet/ensembles/XBAODE.cc b/bayesnet/ensembles/XBAODE.cc index bc8f657..982dbff 100644 --- a/bayesnet/ensembles/XBAODE.cc +++ b/bayesnet/ensembles/XBAODE.cc @@ -10,7 +10,7 @@ #include #include "XBAODE.h" #include "bayesnet/classifiers/XSPODE.h" -#include "bayesnet/utils/TensorUtils.hpp" +#include "bayesnet/utils/TensorUtils.h" namespace bayesnet { XBAODE::XBAODE() diff --git a/bayesnet/utils/TensorUtils.hpp b/bayesnet/utils/TensorUtils.h similarity index 95% rename from bayesnet/utils/TensorUtils.hpp rename to bayesnet/utils/TensorUtils.h index dffd879..1834051 100644 --- a/bayesnet/utils/TensorUtils.hpp +++ b/bayesnet/utils/TensorUtils.h @@ -1,5 +1,5 @@ -#ifndef TENSORUTILS_HPP -#define TENSORUTILS_HPP +#ifndef TENSORUTILS_H +#define TENSORUTILS_H #include #include namespace bayesnet { @@ -48,4 +48,4 @@ namespace bayesnet { }; } -#endif // TENSORUTILS_HPP \ No newline at end of file +#endif // TENSORUTILS_H \ No newline at end of file