Fix XSpode predict

This commit is contained in:
2025-03-10 11:18:04 +01:00
parent 06621ea361
commit ca54f799ee
6 changed files with 310 additions and 23 deletions

View File

@@ -22,8 +22,11 @@ namespace bayesnet {
auto n_classes = states.at(className).size();
metrics = Metrics(dataset, features, className, n_classes);
model.initialize();
std::cout << "Ahora buildmodel"<< std::endl;
buildModel(weights);
std::cout << "Ahora trainmodel"<< std::endl;
trainModel(weights, smoothing);
std::cout << "Después de trainmodel"<< std::endl;
fitted = true;
return *this;
}

View File

@@ -3,7 +3,12 @@
// SPDX-FileType: SOURCE
// SPDX-License-Identifier: MIT
// ***************************************************************
#include <limits>
#include <algorithm>
#include <numeric>
#include <cmath>
#include <stdexcept>
#include <sstream>
#include "XSPODE.h"
@@ -20,6 +25,17 @@ namespace bayesnet {
initializer_{ 1.0 },
semaphore_{ CountingSemaphore::getInstance() }, Classifier(Network())
{
validHyperparameters = { "parent" };
}
void XSpode::setHyperparameters(const nlohmann::json& hyperparameters_)
{
auto hyperparameters = hyperparameters_;
if (hyperparameters.contains("parent")) {
superParent_ = hyperparameters["parent"];
hyperparameters.erase("parent");
}
Classifier::setHyperparameters(hyperparameters);
}
void XSpode::fit(std::vector<std::vector<int>>& X, std::vector<int>& y, torch::Tensor& weights_, const Smoothing_t smoothing)
@@ -28,6 +44,7 @@ namespace bayesnet {
n = X.size();
buildModel(weights_);
trainModel(weights_, smoothing);
fitted=true;
}
// --------------------------------------
@@ -89,7 +106,7 @@ namespace bayesnet {
for (int f = 0; f < nFeatures_; f++) {
instance[f] = dataset[f][i].item<int>();
}
instance[nFeatures_] = dataset[-1].item<int>();
instance[nFeatures_] = dataset[-1][i].item<int>();
addSample(instance, weights[i].item<double>());
}
@@ -205,7 +222,6 @@ namespace bayesnet {
}
}
}
}
// --------------------------------------
@@ -218,8 +234,10 @@ namespace bayesnet {
// --------------------------------------
std::vector<double> XSpode::predict_proba(const std::vector<int>& instance) const
{
if (!fitted) {
throw std::logic_error(CLASSIFIER_NOT_FITTED);
}
std::vector<double> probs(statesClass_, 0.0);
// Multiply p(c) × p(x_sp | c)
int spVal = instance[superParent_];
for (int c = 0; c < statesClass_; c++) {
@@ -295,9 +313,6 @@ namespace bayesnet {
}
std::vector<int> XSpode::predict(std::vector<std::vector<int>>& test_data)
{
if (!fitted) {
throw std::logic_error(CLASSIFIER_NOT_FITTED);
}
auto probabilities = predict_proba(test_data);
std::vector<int> predictions(probabilities.size(), 0);
@@ -375,5 +390,34 @@ namespace bayesnet {
}
std::vector<int>& XSpode::getStates() { return states_; }
// ------------------------------------------------------
// Predict overrides (classifier interface)
// ------------------------------------------------------
torch::Tensor predict(torch::Tensor& X)
{
auto X_ = TensorUtils::to_matrix(X);
return predict(X_);
}
std::vector<int> predict(std::vector<std::vector<int>>& X)
{
auto proba = predict_proba(X);
std::vector<int> predictions(proba.size(), 0);
for (size_t i = 0; i < proba.size(); i++) {
predictions[i] = std::distance(proba[i].begin(), std::max_element(proba[i].begin(), proba[i].end()));
}
return predictions;
}
torch::Tensor predict_proba(torch::Tensor& X)
{
auto X_ = TensorUtils::to_matrix(X);
return predict_proba(X_);
}
torch::Tensor Classifier::predict(torch::Tensor& X)
{
auto X_ = TensorUtils::to_matrix(X);
auto predict = predict(X_);
return TensorUtils::to_tensor(predict);
}
}

View File

@@ -8,15 +8,6 @@
#define XSPODE_H
#include <vector>
#include <map>
#include <stdexcept>
#include <algorithm>
#include <numeric>
#include <string>
#include <cmath>
#include <limits>
#include <sstream>
#include <iostream>
#include <torch/torch.h>
#include "Classifier.h"
#include "bayesnet/utils/CountingSemaphore.h"
@@ -32,7 +23,6 @@ namespace bayesnet {
std::vector<int> predict(std::vector<std::vector<int>>& test_data);
void normalize(std::vector<double>& v) const;
std::string to_string() const;
int statesClass() const;
int getNFeatures() const;
int getNumberOfNodes() const override;
int getNumberOfEdges() const override;
@@ -41,6 +31,15 @@ namespace bayesnet {
std::vector<int>& getStates();
std::vector<std::string> graph(const std::string& title) const override { return std::vector<std::string>({title}); }
void fit(std::vector<std::vector<int>>& X, std::vector<int>& y, torch::Tensor& weights_, const Smoothing_t smoothing);
void setHyperparameters(const nlohmann::json& hyperparameters_) override;
//
// Classifier interface
//
torch::Tensor predict(torch::Tensor& X) override;
std::vector<int> predict(std::vector<std::vector<int>>& X) override;
torch::Tensor predict_proba(torch::Tensor& X) override;
std::vector<std::vector<double>> predict_proba(std::vector<std::vector<int>>& X) override;
protected:
void buildModel(const torch::Tensor& weights) override;
void trainModel(const torch::Tensor& weights, const bayesnet::Smoothing_t smoothing) override;