Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
7 : #include "A2DE.h"
8 :
9 : namespace bayesnet {
10 12 : A2DE::A2DE(bool predict_voting) : Ensemble(predict_voting)
11 : {
12 24 : validHyperparameters = { "predict_voting" };
13 36 : }
14 8 : void A2DE::setHyperparameters(const nlohmann::json& hyperparameters_)
15 : {
16 8 : auto hyperparameters = hyperparameters_;
17 8 : if (hyperparameters.contains("predict_voting")) {
18 8 : predict_voting = hyperparameters["predict_voting"];
19 8 : hyperparameters.erase("predict_voting");
20 : }
21 8 : Classifier::setHyperparameters(hyperparameters);
22 8 : }
23 16 : void A2DE::buildModel(const torch::Tensor& weights)
24 : {
25 16 : models.clear();
26 16 : significanceModels.clear();
27 124 : for (int i = 0; i < features.size() - 1; ++i) {
28 564 : for (int j = i + 1; j < features.size(); ++j) {
29 456 : auto model = std::make_unique<SPnDE>(std::vector<int>({ i, j }));
30 456 : models.push_back(std::move(model));
31 456 : }
32 : }
33 16 : n_models = static_cast<unsigned>(models.size());
34 16 : significanceModels = std::vector<double>(n_models, 1.0);
35 16 : }
36 4 : std::vector<std::string> A2DE::graph(const std::string& title) const
37 : {
38 4 : return Ensemble::graph(title);
39 : }
40 : }
|