Tests XSpode & XBAODE
This commit is contained in:
@@ -3,14 +3,14 @@
|
||||
// SPDX-FileType: SOURCE
|
||||
// SPDX-License-Identifier: MIT
|
||||
// ***************************************************************
|
||||
#include "XSPODE.h"
|
||||
#include "bayesnet/utils/TensorUtils.h"
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <limits>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include "XSPODE.h"
|
||||
#include "bayesnet/utils/TensorUtils.h"
|
||||
|
||||
namespace bayesnet {
|
||||
|
||||
@@ -35,7 +35,7 @@ namespace bayesnet {
|
||||
Classifier::setHyperparameters(hyperparameters);
|
||||
}
|
||||
|
||||
void XSpode::fit(torch::Tensor & X, torch::Tensor& y, torch::Tensor& weights_, const Smoothing_t smoothing)
|
||||
void XSpode::fitx(torch::Tensor & X, torch::Tensor& y, torch::Tensor& weights_, const Smoothing_t smoothing)
|
||||
{
|
||||
m = X.size(1);
|
||||
n = X.size(0);
|
||||
@@ -390,9 +390,8 @@ namespace bayesnet {
|
||||
}
|
||||
int XSpode::getNumberOfEdges() const
|
||||
{
|
||||
return nFeatures_ * (2 * nFeatures_ - 1);
|
||||
return 2 * nFeatures_ + 1;
|
||||
}
|
||||
std::vector<int>& XSpode::getStates() { return states_; }
|
||||
|
||||
// ------------------------------------------------------
|
||||
// Predict overrides (classifier interface)
|
||||
|
@@ -29,7 +29,7 @@ namespace bayesnet {
|
||||
int getClassNumStates() const override;
|
||||
std::vector<int>& getStates();
|
||||
std::vector<std::string> graph(const std::string& title) const override { return std::vector<std::string>({ title }); }
|
||||
void fit(torch::Tensor& X, torch::Tensor& y, torch::Tensor& weights_, const Smoothing_t smoothing);
|
||||
void fitx(torch::Tensor& X, torch::Tensor& y, torch::Tensor& weights_, const Smoothing_t smoothing);
|
||||
void setHyperparameters(const nlohmann::json& hyperparameters_) override;
|
||||
|
||||
//
|
||||
|
Reference in New Issue
Block a user