Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
7 : #ifndef SPODELD_H
8 : #define SPODELD_H
9 : #include "SPODE.h"
10 : #include "Proposal.h"
11 :
12 : namespace bayesnet {
13 : class SPODELd : public SPODE, public Proposal {
14 : public:
15 : explicit SPODELd(int root);
16 480 : virtual ~SPODELd() = default;
17 : SPODELd& fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, map<std::string, std::vector<int>>& states) override;
18 : SPODELd& fit(torch::Tensor& dataset, const std::vector<std::string>& features, const std::string& className, map<std::string, std::vector<int>>& states) override;
19 : SPODELd& commonFit(const std::vector<std::string>& features, const std::string& className, map<std::string, std::vector<int>>& states);
20 : std::vector<std::string> graph(const std::string& name = "SPODE") const override;
21 : torch::Tensor predict(torch::Tensor& X) override;
22 : static inline std::string version() { return "0.0.1"; };
23 : };
24 : }
25 : #endif // !SPODELD_H
|