9.0 KiB
9.0 KiB
<html lang="en">
<head>
</head>
</html>
LCOV - code coverage report | ||||||||||||||||||||||
![]() | ||||||||||||||||||||||
|
||||||||||||||||||||||
![]() |
Line data Source code 1 : // *************************************************************** 2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez 3 : // SPDX-FileType: SOURCE 4 : // SPDX-License-Identifier: MIT 5 : // *************************************************************** 6 : 7 : #ifndef BOOSTAODE_H 8 : #define BOOSTAODE_H 9 : #include <map> 10 : #include "bayesnet/classifiers/SPODE.h" 11 : #include "bayesnet/feature_selection/FeatureSelect.h" 12 : #include "Ensemble.h" 13 : namespace bayesnet { 14 : const struct { 15 : std::string CFS = "CFS"; 16 : std::string FCBF = "FCBF"; 17 : std::string IWSS = "IWSS"; 18 : }SelectFeatures; 19 : const struct { 20 : std::string ASC = "asc"; 21 : std::string DESC = "desc"; 22 : std::string RAND = "rand"; 23 : }Orders; 24 : class BoostAODE : public Ensemble { 25 : public: 26 : explicit BoostAODE(bool predict_voting = false); 27 44 : virtual ~BoostAODE() = default; 28 : std::vector<std::string> graph(const std::string& title = "BoostAODE") const override; 29 : void setHyperparameters(const nlohmann::json& hyperparameters_) override; 30 : protected: 31 : void buildModel(const torch::Tensor& weights) override; 32 : void trainModel(const torch::Tensor& weights) override; 33 : private: 34 : std::tuple<torch::Tensor&, double, bool> update_weights_block(int k, torch::Tensor& ytrain, torch::Tensor& weights); 35 : std::vector<int> initializeModels(); 36 : torch::Tensor X_train, y_train, X_test, y_test; 37 : // Hyperparameters 38 : bool bisection = true; // if true, use bisection strategy to add k models at once to the ensemble 39 : int maxTolerance = 3; 40 : std::string order_algorithm; // order to process the KBest features asc, desc, rand 41 : bool convergence = true; //if true, stop when the model does not improve 42 : bool convergence_best = false; // whether to keep the best accuracy to the moment or the last accuracy as prior accuracy 43 : bool selectFeatures = false; // if true, use feature selection 44 : 
std::string select_features_algorithm = Orders.DESC; // Selected feature selection algorithm 45 : FeatureSelect* featureSelector = nullptr; 46 : double threshold = -1; 47 : bool block_update = false; 48 : }; 49 : } 50 : #endif |
![]() |
Generated by: LCOV version 2.0-1 |
</html>