Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
7 : #include "AODE.h"
8 :
9 : namespace bayesnet {
10 38 : AODE::AODE(bool predict_voting) : Ensemble(predict_voting)
11 : {
12 76 : validHyperparameters = { "predict_voting" };
13 :
14 114 : }
15 2 : void AODE::setHyperparameters(const nlohmann::json& hyperparameters_)
16 : {
17 2 : auto hyperparameters = hyperparameters_;
18 2 : if (hyperparameters.contains("predict_voting")) {
19 2 : predict_voting = hyperparameters["predict_voting"];
20 2 : hyperparameters.erase("predict_voting");
21 : }
22 2 : Classifier::setHyperparameters(hyperparameters);
23 2 : }
24 12 : void AODE::buildModel(const torch::Tensor& weights)
25 : {
26 12 : models.clear();
27 12 : significanceModels.clear();
28 94 : for (int i = 0; i < features.size(); ++i) {
29 82 : models.push_back(std::make_unique<SPODE>(i));
30 : }
31 12 : n_models = models.size();
32 12 : significanceModels = std::vector<double>(n_models, 1.0);
33 12 : }
34 2 : std::vector<std::string> AODE::graph(const std::string& title) const
35 : {
36 2 : return Ensemble::graph(title);
37 : }
38 : }
|