Tests XSpode & XBAODE

This commit is contained in:
2025-03-12 13:46:04 +01:00
parent 71b05cc1a7
commit 3bdb14bd65
12 changed files with 450 additions and 644 deletions

View File

@@ -3,14 +3,14 @@
// SPDX-FileType: SOURCE
// SPDX-License-Identifier: MIT
// ***************************************************************
#include "XSPODE.h"
#include "bayesnet/utils/TensorUtils.h"
#include <algorithm>
#include <cmath>
#include <limits>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include "XSPODE.h"
#include "bayesnet/utils/TensorUtils.h"
namespace bayesnet {
@@ -35,7 +35,7 @@ namespace bayesnet {
Classifier::setHyperparameters(hyperparameters);
}
void XSpode::fit(torch::Tensor & X, torch::Tensor& y, torch::Tensor& weights_, const Smoothing_t smoothing)
void XSpode::fitx(torch::Tensor & X, torch::Tensor& y, torch::Tensor& weights_, const Smoothing_t smoothing)
{
m = X.size(1);
n = X.size(0);
@@ -390,9 +390,8 @@ namespace bayesnet {
}
int XSpode::getNumberOfEdges() const
{
return nFeatures_ * (2 * nFeatures_ - 1);
return 2 * nFeatures_ + 1;
}
std::vector<int>& XSpode::getStates() { return states_; }
// ------------------------------------------------------
// Predict overrides (classifier interface)

View File

@@ -29,7 +29,7 @@ namespace bayesnet {
int getClassNumStates() const override;
std::vector<int>& getStates();
std::vector<std::string> graph(const std::string& title) const override { return std::vector<std::string>({ title }); }
void fit(torch::Tensor& X, torch::Tensor& y, torch::Tensor& weights_, const Smoothing_t smoothing);
void fitx(torch::Tensor& X, torch::Tensor& y, torch::Tensor& weights_, const Smoothing_t smoothing);
void setHyperparameters(const nlohmann::json& hyperparameters_) override;
//