BayesNet 1.0.5
Bayesian Network Classifiers using libtorch from scratch
Loading...
Searching...
No Matches
Boost.h
1// ***************************************************************
2// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3// SPDX-FileType: SOURCE
4// SPDX-License-Identifier: MIT
5// ***************************************************************
6
7#ifndef BOOST_H
8#define BOOST_H
9#include <string>
10#include <tuple>
11#include <vector>
12#include <nlohmann/json.hpp>
13#include <torch/torch.h>
14#include "Ensemble.h"
15#include "bayesnet/feature_selection/FeatureSelect.h"
16namespace bayesnet {
17 const struct {
18 std::string CFS = "CFS";
19 std::string FCBF = "FCBF";
20 std::string IWSS = "IWSS";
21 }SelectFeatures;
22 const struct {
23 std::string ASC = "asc";
24 std::string DESC = "desc";
25 std::string RAND = "rand";
26 }Orders;
27 class Boost : public Ensemble {
28 public:
29 explicit Boost(bool predict_voting = false);
30 virtual ~Boost() = default;
31 void setHyperparameters(const nlohmann::json& hyperparameters_) override;
32 protected:
33 std::vector<int> featureSelection(torch::Tensor& weights_);
34 void buildModel(const torch::Tensor& weights) override;
35 std::tuple<torch::Tensor&, double, bool> update_weights(torch::Tensor& ytrain, torch::Tensor& ypred, torch::Tensor& weights);
36 std::tuple<torch::Tensor&, double, bool> update_weights_block(int k, torch::Tensor& ytrain, torch::Tensor& weights);
37 torch::Tensor X_train, y_train, X_test, y_test;
38 // Hyperparameters
39 bool bisection = true; // if true, use bisection stratety to add k models at once to the ensemble
40 int maxTolerance = 3;
41 std::string order_algorithm; // order to process the KBest features asc, desc, rand
42 bool convergence = true; //if true, stop when the model does not improve
43 bool convergence_best = false; // wether to keep the best accuracy to the moment or the last accuracy as prior accuracy
44 bool selectFeatures = false; // if true, use feature selection
45 std::string select_features_algorithm = Orders.DESC; // Selected feature selection algorithm
46 FeatureSelect* featureSelector = nullptr;
47 double threshold = -1;
48 bool block_update = false;
49
50 };
51}
52#endif