Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
7 : #ifndef TANLD_H
8 : #define TANLD_H
9 : #include "TAN.h"
10 : #include "Proposal.h"
11 :
12 : namespace bayesnet {
13 : class TANLd : public TAN, public Proposal {
14 : private:
15 : public:
16 : TANLd();
17 20 : virtual ~TANLd() = default;
18 : TANLd& fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, map<std::string, std::vector<int>>& states) override;
19 : std::vector<std::string> graph(const std::string& name = "TAN") const override;
20 : torch::Tensor predict(torch::Tensor& X) override;
21 : static inline std::string version() { return "0.0.1"; };
22 : };
23 : }
24 : #endif // !TANLD_H
|