Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
7 : #ifndef BOOSTAODE_H
8 : #define BOOSTAODE_H
9 : #include <map>
10 : #include "bayesnet/classifiers/SPODE.h"
11 : #include "bayesnet/feature_selection/FeatureSelect.h"
12 : #include "Ensemble.h"
13 : namespace bayesnet {
14 : const struct { // Named string constants "CFS", "FCBF", "IWSS"; presumably the accepted feature-selection algorithm names -- confirm against the implementation file
15 : std::string CFS = "CFS";
16 : std::string FCBF = "FCBF";
17 : std::string IWSS = "IWSS";
18 : }SelectFeatures; // const object at namespace scope has internal linkage: every translation unit including this header gets its own copy
19 : const struct { // Named string constants "asc", "desc", "rand": the orders in which features can be processed (see the order_algorithm member of BoostAODE)
20 : std::string ASC = "asc";
21 : std::string DESC = "desc";
22 : std::string RAND = "rand";
23 : }Orders; // const object at namespace scope has internal linkage: every translation unit including this header gets its own copy
24 : class BoostAODE : public Ensemble { // Ensemble subclass; apart from the defaulted destructor, this header only declares its members
25 : public:
26 : explicit BoostAODE(bool predict_voting = false); // explicit: a bool is never implicitly converted to a BoostAODE
27 88 : virtual ~BoostAODE() = default; // compiler-generated; it performs no explicit cleanup of the raw featureSelector pointer
28 : std::vector<std::string> graph(const std::string& title = "BoostAODE") const override; // title defaults to the class name
29 : void setHyperparameters(const nlohmann::json& hyperparameters_) override; // presumably fills the hyperparameter members below from JSON -- confirm in the implementation file
30 : protected:
31 : void buildModel(const torch::Tensor& weights) override;
32 : void trainModel(const torch::Tensor& weights) override;
33 : private:
34 : std::tuple<torch::Tensor&, double, bool> update_weights_block(int k, torch::Tensor& ytrain, torch::Tensor& weights); // first tuple element is a Tensor reference, so the referred tensor must outlive the returned tuple; the meaning of the double and bool is not visible here
35 : std::vector<int> initializeModels();
36 : torch::Tensor X_train, y_train, X_test, y_test; // presumably the internal train/test split used while fitting -- TODO confirm
37 : // Hyperparameters
38 : bool bisection = true; // if true, use bisection strategy to add k models at once to the ensemble
39 : int maxTolerance = 3; // presumably how many non-improving steps are tolerated before stopping -- TODO confirm
40 : std::string order_algorithm; // order to process the KBest features: asc, desc, rand; no initializer, so it starts as an empty string
41 : bool convergence = true; // if true, stop when the model does not improve
42 : bool convergence_best = false; // whether to keep the best accuracy so far or the last accuracy as prior accuracy
43 : bool selectFeatures = false; // if true, use feature selection
44 : std::string select_features_algorithm = Orders.DESC; // Selected feature selection algorithm. NOTE(review): the default is Orders.DESC ("desc"), which is not a SelectFeatures value -- confirm it is always overwritten before use
45 : FeatureSelect* featureSelector = nullptr; // raw pointer, null until assigned; ownership is not visible here -- verify it is released elsewhere if this class owns it
46 : double threshold = -1; // NOTE(review): -1 looks like a "not set" sentinel -- confirm
47 : bool block_update = false; // presumably enables update_weights_block() -- TODO confirm
48 : };
49 : }
50 : #endif
|