continue fixing xspode

This commit is contained in:
2025-03-10 12:18:10 +01:00
parent 6cfbc482d8
commit 7a8e0391dc
4 changed files with 26 additions and 40 deletions

View File

@@ -10,6 +10,7 @@
#include <stdexcept> #include <stdexcept>
#include <sstream> #include <sstream>
#include "XSPODE.h" #include "XSPODE.h"
#include "bayesnet/utils/TensorUtils.h"
namespace bayesnet { namespace bayesnet {
@@ -299,30 +300,6 @@ namespace bayesnet {
return probabilities; return probabilities;
} }
// --------------------------------------
// predict
// --------------------------------------
//
// Return the class argmax( P(c|x) ).
// --------------------------------------
int XSpode::predict(const std::vector<int>& instance) const
{
auto p = predict_proba(instance);
return static_cast<int>(std::distance(p.begin(),
std::max_element(p.begin(), p.end())));
}
std::vector<int> XSpode::predict(std::vector<std::vector<int>>& test_data)
{
auto probabilities = predict_proba(test_data);
std::vector<int> predictions(probabilities.size(), 0);
for (size_t i = 0; i < probabilities.size(); i++) {
predictions[i] = std::distance(probabilities[i].begin(), std::max_element(probabilities[i].begin(), probabilities[i].end()));
}
return predictions;
}
// -------------------------------------- // --------------------------------------
// Utility: normalize // Utility: normalize
// -------------------------------------- // --------------------------------------
@@ -393,26 +370,36 @@ namespace bayesnet {
// ------------------------------------------------------ // ------------------------------------------------------
// Predict overrides (classifier interface) // Predict overrides (classifier interface)
// ------------------------------------------------------ // ------------------------------------------------------
torch::Tensor predict(torch::Tensor& X) int XSpode::predict(const std::vector<int>& instance) const
{ {
auto X_ = TensorUtils::to_matrix(X); auto p = predict_proba(instance);
return predict(X_); return static_cast<int>(std::distance(p.begin(),
std::max_element(p.begin(), p.end())));
} }
std::vector<int> predict(std::vector<std::vector<int>>& X) std::vector<int> XSpode::predict(std::vector<std::vector<int>>& test_data)
{ {
auto proba = predict_proba(X); auto probabilities = predict_proba(test_data);
std::vector<int> predictions(proba.size(), 0); std::vector<int> predictions(probabilities.size(), 0);
for (size_t i = 0; i < proba.size(); i++) {
predictions[i] = std::distance(proba[i].begin(), std::max_element(proba[i].begin(), proba[i].end())); for (size_t i = 0; i < probabilities.size(); i++) {
predictions[i] = std::distance(probabilities[i].begin(), std::max_element(probabilities[i].begin(), probabilities[i].end()));
} }
return predictions; return predictions;
} }
torch::Tensor predict_proba(torch::Tensor& X) torch::Tensor XSpode::predict(torch::Tensor& X)
{
auto X_ = TensorUtils::to_matrix(X);
auto result = predict(X_);
return TensorUtils::to_tensor(result);
}
torch::Tensor XSpode::predict_proba(torch::Tensor& X)
{ {
auto X_ = TensorUtils::to_matrix(X); auto X_ = TensorUtils::to_matrix(X);
return predict_proba(X_); auto result = predict_proba(X_);
return TensorUtils::to_tensor<double>(result);
} }
torch::Tensor Classifier::predict(torch::Tensor& X) torch::Tensor XSpode::predict(torch::Tensor& X)
{ {
auto X_ = TensorUtils::to_matrix(X); auto X_ = TensorUtils::to_matrix(X);
auto predict = predict(X_); auto predict = predict(X_);

View File

@@ -20,7 +20,6 @@ namespace bayesnet {
std::vector<double> predict_proba(const std::vector<int>& instance) const; std::vector<double> predict_proba(const std::vector<int>& instance) const;
std::vector<std::vector<double>> predict_proba(const std::vector<std::vector<int>>& test_data); std::vector<std::vector<double>> predict_proba(const std::vector<std::vector<int>>& test_data);
int predict(const std::vector<int>& instance) const; int predict(const std::vector<int>& instance) const;
std::vector<int> predict(std::vector<std::vector<int>>& test_data);
void normalize(std::vector<double>& v) const; void normalize(std::vector<double>& v) const;
std::string to_string() const; std::string to_string() const;
int getNFeatures() const; int getNFeatures() const;

View File

@@ -10,7 +10,7 @@
#include <tuple> #include <tuple>
#include "XBAODE.h" #include "XBAODE.h"
#include "bayesnet/classifiers/XSPODE.h" #include "bayesnet/classifiers/XSPODE.h"
#include "bayesnet/utils/TensorUtils.hpp" #include "bayesnet/utils/TensorUtils.h"
namespace bayesnet { namespace bayesnet {
XBAODE::XBAODE() XBAODE::XBAODE()

View File

@@ -1,5 +1,5 @@
#ifndef TENSORUTILS_HPP #ifndef TENSORUTILS_H
#define TENSORUTILS_HPP #define TENSORUTILS_H
#include <torch/torch.h> #include <torch/torch.h>
#include <vector> #include <vector>
namespace bayesnet { namespace bayesnet {
@@ -48,4 +48,4 @@ namespace bayesnet {
}; };
} }
#endif // TENSORUTILS_HPP #endif // TENSORUTILS_H