Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
7 : #ifndef KDBLD_H
8 : #define KDBLD_H
9 : #include "Proposal.h"
10 : #include "KDB.h"
11 :
12 : namespace bayesnet {
13 : class KDBLd : public KDB, public Proposal {
14 : private:
15 : public:
16 : explicit KDBLd(int k);
17 10 : virtual ~KDBLd() = default;
18 : KDBLd& fit(torch::Tensor& X, torch::Tensor& y, const std::vector<std::string>& features, const std::string& className, map<std::string, std::vector<int>>& states) override;
19 : std::vector<std::string> graph(const std::string& name = "KDB") const override;
20 : torch::Tensor predict(torch::Tensor& X) override;
21 : static inline std::string version() { return "0.0.1"; };
22 : };
23 : }
24 : #endif // !KDBLD_H
|