Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
7 : #ifndef KDB_H
8 : #define KDB_H
9 : #include <torch/torch.h>
10 : #include "bayesnet/utils/bayesnetUtils.h"
11 : #include "Classifier.h"
12 : namespace bayesnet {
13 : class KDB : public Classifier {
14 : private:
15 : int k;
16 : float theta;
17 : void add_m_edges(int idx, std::vector<int>& S, torch::Tensor& weights);
18 : protected:
19 : void buildModel(const torch::Tensor& weights) override;
20 : public:
21 : explicit KDB(int k, float theta = 0.03);
22 44 : virtual ~KDB() = default;
23 : void setHyperparameters(const nlohmann::json& hyperparameters_) override;
24 : std::vector<std::string> graph(const std::string& name = "KDB") const override;
25 : };
26 : }
27 : #endif
|