Fix XSpode

This commit is contained in:
2025-03-10 14:23:47 +01:00
parent d1b235261e
commit 86cccb6c7b
5 changed files with 13 additions and 20 deletions

View File

@@ -28,7 +28,7 @@ namespace bayesnet {
int getNumberOfStates() const override;
int getClassNumStates() const override;
std::vector<int>& getStates();
std::vector<std::string> graph(const std::string& title) const override { return std::vector<std::string>({title}); }
std::vector<std::string> graph(const std::string& title) const override { return std::vector<std::string>({ title }); }
void fit(std::vector<std::vector<int>>& X, std::vector<int>& y, torch::Tensor& weights_, const Smoothing_t smoothing);
void setHyperparameters(const nlohmann::json& hyperparameters_) override;
@@ -38,7 +38,6 @@ namespace bayesnet {
torch::Tensor predict(torch::Tensor& X) override;
std::vector<int> predict(std::vector<std::vector<int>>& X) override;
torch::Tensor predict_proba(torch::Tensor& X) override;
std::vector<std::vector<double>> predict_proba(std::vector<std::vector<int>>& X) override;
protected:
void buildModel(const torch::Tensor& weights) override;
void trainModel(const torch::Tensor& weights, const bayesnet::Smoothing_t smoothing) override;