Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
7 : #ifndef AODELD_H
8 : #define AODELD_H
9 : #include "bayesnet/classifiers/Proposal.h"
10 : #include "bayesnet/classifiers/SPODELd.h"
11 : #include "Ensemble.h"
12 :
13 : namespace bayesnet {
14 : class AODELd : public Ensemble, public Proposal {
15 : public:
16 : AODELd(bool predict_voting = true);
17 5 : virtual ~AODELd() = default;
18 : AODELd& fit(torch::Tensor& X_, torch::Tensor& y_, const std::vector<std::string>& features_, const std::string& className_, map<std::string, std::vector<int>>& states_) override;
19 : std::vector<std::string> graph(const std::string& name = "AODELd") const override;
20 : protected:
21 : void trainModel(const torch::Tensor& weights) override;
22 : void buildModel(const torch::Tensor& weights) override;
23 : };
24 : }
25 : #endif // !AODELD_H
|