8.9 KiB
8.9 KiB
<html lang="en">
<head>
</head>
</html>
LCOV - code coverage report | ||||||||||||||||||||||
![]() | ||||||||||||||||||||||
|
||||||||||||||||||||||
![]() |
Line data Source code 1 : // *************************************************************** 2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez 3 : // SPDX-FileType: SOURCE 4 : // SPDX-License-Identifier: MIT 5 : // *************************************************************** 6 : 7 : #include "AODELd.h" 8 : 9 : namespace bayesnet { 10 187 : AODELd::AODELd(bool predict_voting) : Ensemble(predict_voting), Proposal(dataset, features, className) 11 : { 12 187 : } 13 55 : AODELd& AODELd::fit(torch::Tensor& X_, torch::Tensor& y_, const std::vector<std::string>& features_, const std::string& className_, map<std::string, std::vector<int>>& states_) 14 : { 15 55 : checkInput(X_, y_); 16 55 : features = features_; 17 55 : className = className_; 18 55 : Xf = X_; 19 55 : y = y_; 20 : // Fills std::vectors Xv & yv with the data from tensors X_ (discretized) & y 21 55 : states = fit_local_discretization(y); 22 : // We have discretized the input data 23 : // 1st we need to fit the model to build the normal TAN structure, TAN::fit initializes the base Bayesian network 24 55 : Ensemble::fit(dataset, features, className, states); 25 55 : return *this; 26 : 27 : } 28 55 : void AODELd::buildModel(const torch::Tensor& weights) 29 : { 30 55 : models.clear(); 31 462 : for (int i = 0; i < features.size(); ++i) { 32 407 : models.push_back(std::make_unique<SPODELd>(i)); 33 : } 34 55 : n_models = models.size(); 35 55 : significanceModels = std::vector<double>(n_models, 1.0); 36 55 : } 37 55 : void AODELd::trainModel(const torch::Tensor& weights) 38 : { 39 462 : for (const auto& model : models) { 40 407 : model->fit(Xf, y, features, className, states); 41 : } 42 55 : } 43 11 : std::vector<std::string> AODELd::graph(const std::string& name) const 44 : { 45 11 : return Ensemble::graph(name); 46 : } 47 : } |
![]() |
Generated by: LCOV version 2.0-1 |
</html>