Fix XSpode predict

This commit is contained in:
2025-03-10 11:18:04 +01:00
parent 06621ea361
commit ca54f799ee
6 changed files with 310 additions and 23 deletions

View File

@@ -8,15 +8,6 @@
#define XSPODE_H
#include <vector>
#include <map>
#include <stdexcept>
#include <algorithm>
#include <numeric>
#include <string>
#include <cmath>
#include <limits>
#include <sstream>
#include <iostream>
#include <torch/torch.h>
#include "Classifier.h"
#include "bayesnet/utils/CountingSemaphore.h"
@@ -32,7 +23,6 @@ namespace bayesnet {
std::vector<int> predict(std::vector<std::vector<int>>& test_data);
void normalize(std::vector<double>& v) const;
std::string to_string() const;
int statesClass() const;
int getNFeatures() const;
int getNumberOfNodes() const override;
int getNumberOfEdges() const override;
@@ -41,6 +31,15 @@ namespace bayesnet {
std::vector<int>& getStates();
std::vector<std::string> graph(const std::string& title) const override { return std::vector<std::string>({title}); }
void fit(std::vector<std::vector<int>>& X, std::vector<int>& y, torch::Tensor& weights_, const Smoothing_t smoothing);
void setHyperparameters(const nlohmann::json& hyperparameters_) override;
//
// Classifier interface
//
torch::Tensor predict(torch::Tensor& X) override;
std::vector<int> predict(std::vector<std::vector<int>>& X) override;
torch::Tensor predict_proba(torch::Tensor& X) override;
std::vector<std::vector<double>> predict_proba(std::vector<std::vector<int>>& X) override;
protected:
void buildModel(const torch::Tensor& weights) override;
void trainModel(const torch::Tensor& weights, const bayesnet::Smoothing_t smoothing) override;