Fix XSpode predict
This commit is contained in:
@@ -3,7 +3,12 @@
|
||||
// SPDX-FileType: SOURCE
|
||||
// SPDX-License-Identifier: MIT
|
||||
// ***************************************************************
|
||||
|
||||
#include <limits>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <cmath>
|
||||
#include <stdexcept>
|
||||
#include <sstream>
|
||||
#include "XSPODE.h"
|
||||
|
||||
|
||||
@@ -20,6 +25,17 @@ namespace bayesnet {
|
||||
initializer_{ 1.0 },
|
||||
semaphore_{ CountingSemaphore::getInstance() }, Classifier(Network())
|
||||
{
|
||||
validHyperparameters = { "parent" };
|
||||
}
|
||||
|
||||
void XSpode::setHyperparameters(const nlohmann::json& hyperparameters_)
|
||||
{
|
||||
auto hyperparameters = hyperparameters_;
|
||||
if (hyperparameters.contains("parent")) {
|
||||
superParent_ = hyperparameters["parent"];
|
||||
hyperparameters.erase("parent");
|
||||
}
|
||||
Classifier::setHyperparameters(hyperparameters);
|
||||
}
|
||||
|
||||
void XSpode::fit(std::vector<std::vector<int>>& X, std::vector<int>& y, torch::Tensor& weights_, const Smoothing_t smoothing)
|
||||
@@ -28,6 +44,7 @@ namespace bayesnet {
|
||||
n = X.size();
|
||||
buildModel(weights_);
|
||||
trainModel(weights_, smoothing);
|
||||
fitted=true;
|
||||
}
|
||||
|
||||
// --------------------------------------
|
||||
@@ -89,7 +106,7 @@ namespace bayesnet {
|
||||
for (int f = 0; f < nFeatures_; f++) {
|
||||
instance[f] = dataset[f][i].item<int>();
|
||||
}
|
||||
instance[nFeatures_] = dataset[-1].item<int>();
|
||||
instance[nFeatures_] = dataset[-1][i].item<int>();
|
||||
addSample(instance, weights[i].item<double>());
|
||||
}
|
||||
|
||||
@@ -205,7 +222,6 @@ namespace bayesnet {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// --------------------------------------
|
||||
@@ -218,8 +234,10 @@ namespace bayesnet {
|
||||
// --------------------------------------
|
||||
std::vector<double> XSpode::predict_proba(const std::vector<int>& instance) const
|
||||
{
|
||||
if (!fitted) {
|
||||
throw std::logic_error(CLASSIFIER_NOT_FITTED);
|
||||
}
|
||||
std::vector<double> probs(statesClass_, 0.0);
|
||||
|
||||
// Multiply p(c) × p(x_sp | c)
|
||||
int spVal = instance[superParent_];
|
||||
for (int c = 0; c < statesClass_; c++) {
|
||||
@@ -295,9 +313,6 @@ namespace bayesnet {
|
||||
}
|
||||
std::vector<int> XSpode::predict(std::vector<std::vector<int>>& test_data)
|
||||
{
|
||||
if (!fitted) {
|
||||
throw std::logic_error(CLASSIFIER_NOT_FITTED);
|
||||
}
|
||||
auto probabilities = predict_proba(test_data);
|
||||
std::vector<int> predictions(probabilities.size(), 0);
|
||||
|
||||
@@ -375,5 +390,34 @@ namespace bayesnet {
|
||||
}
|
||||
std::vector<int>& XSpode::getStates() { return states_; }
|
||||
|
||||
// ------------------------------------------------------
|
||||
// Predict overrides (classifier interface)
|
||||
// ------------------------------------------------------
|
||||
torch::Tensor predict(torch::Tensor& X)
|
||||
{
|
||||
auto X_ = TensorUtils::to_matrix(X);
|
||||
return predict(X_);
|
||||
}
|
||||
std::vector<int> predict(std::vector<std::vector<int>>& X)
|
||||
{
|
||||
auto proba = predict_proba(X);
|
||||
std::vector<int> predictions(proba.size(), 0);
|
||||
for (size_t i = 0; i < proba.size(); i++) {
|
||||
predictions[i] = std::distance(proba[i].begin(), std::max_element(proba[i].begin(), proba[i].end()));
|
||||
}
|
||||
return predictions;
|
||||
}
|
||||
torch::Tensor predict_proba(torch::Tensor& X)
|
||||
{
|
||||
auto X_ = TensorUtils::to_matrix(X);
|
||||
return predict_proba(X_);
|
||||
}
|
||||
torch::Tensor Classifier::predict(torch::Tensor& X)
|
||||
{
|
||||
auto X_ = TensorUtils::to_matrix(X);
|
||||
auto predict = predict(X_);
|
||||
return TensorUtils::to_tensor(predict);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user