Fix XSpode predict
This commit is contained in:
@@ -22,8 +22,11 @@ namespace bayesnet {
|
||||
auto n_classes = states.at(className).size();
|
||||
metrics = Metrics(dataset, features, className, n_classes);
|
||||
model.initialize();
|
||||
std::cout << "Ahora buildmodel"<< std::endl;
|
||||
buildModel(weights);
|
||||
std::cout << "Ahora trainmodel"<< std::endl;
|
||||
trainModel(weights, smoothing);
|
||||
std::cout << "Después de trainmodel"<< std::endl;
|
||||
fitted = true;
|
||||
return *this;
|
||||
}
|
||||
|
@@ -3,7 +3,12 @@
|
||||
// SPDX-FileType: SOURCE
|
||||
// SPDX-License-Identifier: MIT
|
||||
// ***************************************************************
|
||||
|
||||
#include <limits>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <cmath>
|
||||
#include <stdexcept>
|
||||
#include <sstream>
|
||||
#include "XSPODE.h"
|
||||
|
||||
|
||||
@@ -20,6 +25,17 @@ namespace bayesnet {
|
||||
initializer_{ 1.0 },
|
||||
semaphore_{ CountingSemaphore::getInstance() }, Classifier(Network())
|
||||
{
|
||||
validHyperparameters = { "parent" };
|
||||
}
|
||||
|
||||
void XSpode::setHyperparameters(const nlohmann::json& hyperparameters_)
|
||||
{
|
||||
auto hyperparameters = hyperparameters_;
|
||||
if (hyperparameters.contains("parent")) {
|
||||
superParent_ = hyperparameters["parent"];
|
||||
hyperparameters.erase("parent");
|
||||
}
|
||||
Classifier::setHyperparameters(hyperparameters);
|
||||
}
|
||||
|
||||
void XSpode::fit(std::vector<std::vector<int>>& X, std::vector<int>& y, torch::Tensor& weights_, const Smoothing_t smoothing)
|
||||
@@ -28,6 +44,7 @@ namespace bayesnet {
|
||||
n = X.size();
|
||||
buildModel(weights_);
|
||||
trainModel(weights_, smoothing);
|
||||
fitted=true;
|
||||
}
|
||||
|
||||
// --------------------------------------
|
||||
@@ -89,7 +106,7 @@ namespace bayesnet {
|
||||
for (int f = 0; f < nFeatures_; f++) {
|
||||
instance[f] = dataset[f][i].item<int>();
|
||||
}
|
||||
instance[nFeatures_] = dataset[-1].item<int>();
|
||||
instance[nFeatures_] = dataset[-1][i].item<int>();
|
||||
addSample(instance, weights[i].item<double>());
|
||||
}
|
||||
|
||||
@@ -205,7 +222,6 @@ namespace bayesnet {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// --------------------------------------
|
||||
@@ -218,8 +234,10 @@ namespace bayesnet {
|
||||
// --------------------------------------
|
||||
std::vector<double> XSpode::predict_proba(const std::vector<int>& instance) const
|
||||
{
|
||||
if (!fitted) {
|
||||
throw std::logic_error(CLASSIFIER_NOT_FITTED);
|
||||
}
|
||||
std::vector<double> probs(statesClass_, 0.0);
|
||||
|
||||
// Multiply p(c) × p(x_sp | c)
|
||||
int spVal = instance[superParent_];
|
||||
for (int c = 0; c < statesClass_; c++) {
|
||||
@@ -295,9 +313,6 @@ namespace bayesnet {
|
||||
}
|
||||
std::vector<int> XSpode::predict(std::vector<std::vector<int>>& test_data)
|
||||
{
|
||||
if (!fitted) {
|
||||
throw std::logic_error(CLASSIFIER_NOT_FITTED);
|
||||
}
|
||||
auto probabilities = predict_proba(test_data);
|
||||
std::vector<int> predictions(probabilities.size(), 0);
|
||||
|
||||
@@ -375,5 +390,34 @@ namespace bayesnet {
|
||||
}
|
||||
std::vector<int>& XSpode::getStates() { return states_; }
|
||||
|
||||
// ------------------------------------------------------
|
||||
// Predict overrides (classifier interface)
|
||||
// ------------------------------------------------------
|
||||
torch::Tensor predict(torch::Tensor& X)
|
||||
{
|
||||
auto X_ = TensorUtils::to_matrix(X);
|
||||
return predict(X_);
|
||||
}
|
||||
std::vector<int> predict(std::vector<std::vector<int>>& X)
|
||||
{
|
||||
auto proba = predict_proba(X);
|
||||
std::vector<int> predictions(proba.size(), 0);
|
||||
for (size_t i = 0; i < proba.size(); i++) {
|
||||
predictions[i] = std::distance(proba[i].begin(), std::max_element(proba[i].begin(), proba[i].end()));
|
||||
}
|
||||
return predictions;
|
||||
}
|
||||
torch::Tensor predict_proba(torch::Tensor& X)
|
||||
{
|
||||
auto X_ = TensorUtils::to_matrix(X);
|
||||
return predict_proba(X_);
|
||||
}
|
||||
torch::Tensor Classifier::predict(torch::Tensor& X)
|
||||
{
|
||||
auto X_ = TensorUtils::to_matrix(X);
|
||||
auto predict = predict(X_);
|
||||
return TensorUtils::to_tensor(predict);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
@@ -8,15 +8,6 @@
|
||||
#define XSPODE_H
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <stdexcept>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <string>
|
||||
#include <cmath>
|
||||
#include <limits>
|
||||
#include <sstream>
|
||||
#include <iostream>
|
||||
#include <torch/torch.h>
|
||||
#include "Classifier.h"
|
||||
#include "bayesnet/utils/CountingSemaphore.h"
|
||||
@@ -32,7 +23,6 @@ namespace bayesnet {
|
||||
std::vector<int> predict(std::vector<std::vector<int>>& test_data);
|
||||
void normalize(std::vector<double>& v) const;
|
||||
std::string to_string() const;
|
||||
int statesClass() const;
|
||||
int getNFeatures() const;
|
||||
int getNumberOfNodes() const override;
|
||||
int getNumberOfEdges() const override;
|
||||
@@ -41,6 +31,15 @@ namespace bayesnet {
|
||||
std::vector<int>& getStates();
|
||||
std::vector<std::string> graph(const std::string& title) const override { return std::vector<std::string>({title}); }
|
||||
void fit(std::vector<std::vector<int>>& X, std::vector<int>& y, torch::Tensor& weights_, const Smoothing_t smoothing);
|
||||
void setHyperparameters(const nlohmann::json& hyperparameters_) override;
|
||||
|
||||
//
|
||||
// Classifier interface
|
||||
//
|
||||
torch::Tensor predict(torch::Tensor& X) override;
|
||||
std::vector<int> predict(std::vector<std::vector<int>>& X) override;
|
||||
torch::Tensor predict_proba(torch::Tensor& X) override;
|
||||
std::vector<std::vector<double>> predict_proba(std::vector<std::vector<int>>& X) override;
|
||||
protected:
|
||||
void buildModel(const torch::Tensor& weights) override;
|
||||
void trainModel(const torch::Tensor& weights, const bayesnet::Smoothing_t smoothing) override;
|
||||
|
Reference in New Issue
Block a user