Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
7 : #include "AODE.h"
8 :
9 : namespace bayesnet {
10 76 : AODE::AODE(bool predict_voting) : Ensemble(predict_voting)
11 : {
12 152 : validHyperparameters = { "predict_voting" };
13 :
14 228 : }
15 4 : void AODE::setHyperparameters(const nlohmann::json& hyperparameters_)
16 : {
17 4 : auto hyperparameters = hyperparameters_;
18 4 : if (hyperparameters.contains("predict_voting")) {
19 4 : predict_voting = hyperparameters["predict_voting"];
20 4 : hyperparameters.erase("predict_voting");
21 : }
22 4 : Classifier::setHyperparameters(hyperparameters);
23 4 : }
24 24 : void AODE::buildModel(const torch::Tensor& weights)
25 : {
26 24 : models.clear();
27 24 : significanceModels.clear();
28 188 : for (int i = 0; i < features.size(); ++i) {
29 164 : models.push_back(std::make_unique<SPODE>(i));
30 : }
31 24 : n_models = models.size();
32 24 : significanceModels = std::vector<double>(n_models, 1.0);
33 24 : }
34 4 : std::vector<std::string> AODE::graph(const std::string& title) const
35 : {
36 4 : return Ensemble::graph(title);
37 : }
38 : }
|