continue fixing xspode

This commit is contained in:
2025-03-10 12:18:10 +01:00
parent 6cfbc482d8
commit 7a8e0391dc
4 changed files with 26 additions and 40 deletions

View File

@@ -10,6 +10,7 @@
#include <stdexcept>
#include <sstream>
#include "XSPODE.h"
#include "bayesnet/utils/TensorUtils.h"
namespace bayesnet {
@@ -299,30 +300,6 @@ namespace bayesnet {
return probabilities;
}
// --------------------------------------
// predict
// --------------------------------------
//
// Return the class argmax( P(c|x) ).
// --------------------------------------
int XSpode::predict(const std::vector<int>& instance) const
{
auto p = predict_proba(instance);
return static_cast<int>(std::distance(p.begin(),
std::max_element(p.begin(), p.end())));
}
std::vector<int> XSpode::predict(std::vector<std::vector<int>>& test_data)
{
auto probabilities = predict_proba(test_data);
std::vector<int> predictions(probabilities.size(), 0);
for (size_t i = 0; i < probabilities.size(); i++) {
predictions[i] = std::distance(probabilities[i].begin(), std::max_element(probabilities[i].begin(), probabilities[i].end()));
}
return predictions;
}
// --------------------------------------
// Utility: normalize
// --------------------------------------
@@ -393,26 +370,36 @@ namespace bayesnet {
// ------------------------------------------------------
// Predict overrides (classifier interface)
// ------------------------------------------------------
torch::Tensor predict(torch::Tensor& X)
int XSpode::predict(const std::vector<int>& instance) const
{
auto X_ = TensorUtils::to_matrix(X);
return predict(X_);
auto p = predict_proba(instance);
return static_cast<int>(std::distance(p.begin(),
std::max_element(p.begin(), p.end())));
}
std::vector<int> predict(std::vector<std::vector<int>>& X)
std::vector<int> XSpode::predict(std::vector<std::vector<int>>& test_data)
{
auto proba = predict_proba(X);
std::vector<int> predictions(proba.size(), 0);
for (size_t i = 0; i < proba.size(); i++) {
predictions[i] = std::distance(proba[i].begin(), std::max_element(proba[i].begin(), proba[i].end()));
auto probabilities = predict_proba(test_data);
std::vector<int> predictions(probabilities.size(), 0);
for (size_t i = 0; i < probabilities.size(); i++) {
predictions[i] = std::distance(probabilities[i].begin(), std::max_element(probabilities[i].begin(), probabilities[i].end()));
}
return predictions;
}
torch::Tensor predict_proba(torch::Tensor& X)
torch::Tensor XSpode::predict(torch::Tensor& X)
{
auto X_ = TensorUtils::to_matrix(X);
auto result = predict(X_);
return TensorUtils::to_tensor(result);
}
torch::Tensor XSpode::predict_proba(torch::Tensor& X)
{
auto X_ = TensorUtils::to_matrix(X);
return predict_proba(X_);
auto result = predict_proba(X_);
return TensorUtils::to_tensor<double>(result);
}
torch::Tensor Classifier::predict(torch::Tensor& X)
torch::Tensor XSpode::predict(torch::Tensor& X)
{
auto X_ = TensorUtils::to_matrix(X);
auto predict = predict(X_);