Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
7 : #include "AODELd.h"
8 :
9 : namespace bayesnet {
10 68 : AODELd::AODELd(bool predict_voting) : Ensemble(predict_voting), Proposal(dataset, features, className)
11 : {
12 68 : }
13 20 : AODELd& AODELd::fit(torch::Tensor& X_, torch::Tensor& y_, const std::vector<std::string>& features_, const std::string& className_, map<std::string, std::vector<int>>& states_)
14 : {
15 20 : checkInput(X_, y_);
16 20 : features = features_;
17 20 : className = className_;
18 20 : Xf = X_;
19 20 : y = y_;
20 : // Fills std::vectors Xv & yv with the data from tensors X_ (discretized) & y
21 20 : states = fit_local_discretization(y);
22 : // We have discretized the input data
23 : // 1st we need to fit the model to build the normal TAN structure, TAN::fit initializes the base Bayesian network
24 20 : Ensemble::fit(dataset, features, className, states);
25 20 : return *this;
26 :
27 : }
28 20 : void AODELd::buildModel(const torch::Tensor& weights)
29 : {
30 20 : models.clear();
31 168 : for (int i = 0; i < features.size(); ++i) {
32 148 : models.push_back(std::make_unique<SPODELd>(i));
33 : }
34 20 : n_models = models.size();
35 20 : significanceModels = std::vector<double>(n_models, 1.0);
36 20 : }
37 20 : void AODELd::trainModel(const torch::Tensor& weights)
38 : {
39 168 : for (const auto& model : models) {
40 148 : model->fit(Xf, y, features, className, states);
41 : }
42 20 : }
43 4 : std::vector<std::string> AODELd::graph(const std::string& name) const
44 : {
45 4 : return Ensemble::graph(name);
46 : }
47 : }
|