Files
BayesNet/docs/manual/_boost_a2_d_e_8cc_source.html

31 KiB

<html xmlns="http://www.w3.org/1999/xhtml" lang="en-US"> <head> <script type="text/javascript" src="jquery.js"></script> <script type="text/javascript" src="dynsections.js"></script> <script type="text/javascript" src="clipboard.js"></script> <script type="text/javascript" src="navtreedata.js"></script> <script type="text/javascript" src="navtree.js"></script> <script type="text/javascript" src="resize.js"></script> <script type="text/javascript" src="cookie.js"></script> <script type="text/javascript" src="search/searchdata.js"></script> <script type="text/javascript" src="search/search.js"></script> </head>
BayesNet 1.0.5
Bayesian Network Classifiers using libtorch from scratch
<script type="text/javascript"> /* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699&dn=expat.txt MIT */ var searchBox = new SearchBox("searchBox", "search/",'.html'); /* @license-end */ </script> <script type="text/javascript"> /* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699&dn=expat.txt MIT */ $(function() { codefold.init(0); }); /* @license-end */ </script> <script type="text/javascript" src="menudata.js"></script> <script type="text/javascript" src="menu.js"></script> <script type="text/javascript"> /* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699&dn=expat.txt MIT */ $(function() { initMenu('',true,false,'search.php','Search',true); $(function() { init_search(); }); }); /* @license-end */ </script>
<script type="text/javascript"> /* @license magnet:?xt=urn:btih:d3d9a9a6595521f9666a5e94cc830dab83b65699&dn=expat.txt MIT */ $(function(){initNavTree('_boost_a2_d_e_8cc_source.html',''); initResizable(true); }); /* @license-end */ </script>
Loading...
Searching...
No Matches
BoostA2DE.cc
1// ***************************************************************
2// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3// SPDX-FileType: SOURCE
4// SPDX-License-Identifier: MIT
5// ***************************************************************
6
7#include <set>
8#include <functional>
9#include <limits.h>
10#include <tuple>
11#include <folding.hpp>
12#include "bayesnet/feature_selection/CFS.h"
13#include "bayesnet/feature_selection/FCBF.h"
14#include "bayesnet/feature_selection/IWSS.h"
15#include "BoostA2DE.h"
16
17namespace bayesnet {
18
    // Build a boosted A2DE ensemble; predict_voting selects voting vs.
    // probability aggregation in the Boost base class.
    BoostA2DE::BoostA2DE(bool predict_voting) : Boost(predict_voting)
    {
    }
22 std::vector<int> BoostA2DE::initializeModels()
23 {
24 torch::Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
25 std::vector<int> featuresSelected = featureSelection(weights_);
26 if (featuresSelected.size() < 2) {
27 notes.push_back("No features selected in initialization");
28 status = ERROR;
29 return std::vector<int>();
30 }
31 for (int i = 0; i < featuresSelected.size() - 1; i++) {
32 for (int j = i + 1; j < featuresSelected.size(); j++) {
33 auto parents = { featuresSelected[i], featuresSelected[j] };
34 std::unique_ptr<Classifier> model = std::make_unique<SPnDE>(parents);
35 model->fit(dataset, features, className, states, weights_);
36 models.push_back(std::move(model));
37 significanceModels.push_back(1.0); // They will be updated later in trainModel
38 n_models++;
39 }
40 }
41 notes.push_back("Used features in initialization: " + std::to_string(featuresSelected.size()) + " of " + std::to_string(features.size()) + " with " + select_features_algorithm);
42 return featuresSelected;
43 }
    // AdaBoost-style training loop (Zhi-Hua Zhou, "Ensemble Methods", 2012).
    // Repeatedly fits SPnDE models on feature pairs ranked by mutual
    // information, re-weighting the training samples after each model (or after
    // each pack of models when block_update is set), until the weight update
    // signals a stop, the validation accuracy stops improving, or the pairs run
    // out. The `weights` parameter is unused: the loop maintains its own
    // weights_ tensor starting from a uniform distribution.
    void BoostA2DE::trainModel(const torch::Tensor& weights)
    {
        //
        // Logging setup
        //
        // loguru::set_thread_name("BoostA2DE");
        // loguru::g_stderr_verbosity = loguru::Verbosity_OFF;
        // loguru::add_file("boostA2DE.log", loguru::Truncate, loguru::Verbosity_MAX);

        // Algorithm based on the adaboost algorithm for classification
        // as explained in Ensemble methods (Zhi-Hua Zhou, 2012)
        fitted = true;
        double alpha_t = 0;
        // Uniform initial sample weights: 1/m per sample, double precision.
        torch::Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);
        bool finished = false;
        std::vector<int> featuresUsed;
        if (selectFeatures) {
            // Seed the ensemble with one SPnDE per pair of selected features,
            // then give every seeded model the same significance alpha_t.
            featuresUsed = initializeModels();
            // NOTE(review): when initializeModels() fails it sets status = ERROR
            // and returns empty, yet we still call predict() below on whatever
            // models exist — confirm this early-failure path is intended.
            auto ypred = predict(X_train);
            std::tie(weights_, alpha_t, finished) = update_weights(y_train, ypred, weights_);
            // Update significance of the models
            for (int i = 0; i < n_models; ++i) {
                significanceModels[i] = alpha_t;
            }
            if (finished) {
                return;
            }
        }
        int numItemsPack = 0; // The counter of the models inserted in the current pack
        // Variables to control the accuracy finish condition
        double priorAccuracy = 0.0;
        double improvement = 1.0;
        double convergence_threshold = 1e-4;
        int tolerance = 0; // number of times the accuracy is lower than the convergence_threshold
        // Step 0: Set the finish condition
        // epsilon sub t > 0.5 => inverse the weights policy
        // validation error is not decreasing
        // run out of features
        bool ascending = order_algorithm == Orders.ASC;
        std::mt19937 g{ 173 }; // fixed seed: RAND ordering is reproducible
        std::vector<std::pair<int, int>> pairSelection;
        while (!finished) {
            // Step 1: Build ranking with mutual information
            pairSelection = metrics.SelectKPairs(weights_, featuresUsed, ascending, 0); // Get all the pairs sorted
            if (order_algorithm == Orders.RAND) {
                std::shuffle(pairSelection.begin(), pairSelection.end(), g);
            }
            // Bisection doubles the pack size each tolerated (non-improving) round.
            int k = bisection ? pow(2, tolerance) : 1;
            int counter = 0; // The model counter of the current pack
            // VLOG_SCOPE_F(1, "counter=%d k=%d featureSelection.size: %zu", counter, k, featureSelection.size());
            while (counter++ < k && pairSelection.size() > 0) {
                // Consume the best-ranked remaining pair as the super-parents
                // of the next SPnDE.
                auto feature_pair = pairSelection[0];
                pairSelection.erase(pairSelection.begin());
                std::unique_ptr<Classifier> model;
                model = std::make_unique<SPnDE>(std::vector<int>({ feature_pair.first, feature_pair.second }));
                model->fit(dataset, features, className, states, weights_);
                alpha_t = 0.0;
                if (!block_update) {
                    auto ypred = model->predict(X_train);
                    // Step 3.1: Compute the classifier amount of say
                    std::tie(weights_, alpha_t, finished) = update_weights(y_train, ypred, weights_);
                }
                // Step 3.4: Store classifier and its accuracy to weigh its future vote
                numItemsPack++;
                models.push_back(std::move(model));
                significanceModels.push_back(alpha_t);
                n_models++;
                // VLOG_SCOPE_F(2, "numItemsPack: %d n_models: %d featuresUsed: %zu", numItemsPack, n_models, featuresUsed.size());
            }
            if (block_update) {
                // Update weights once per pack of k models instead of per model.
                std::tie(weights_, alpha_t, finished) = update_weights_block(k, y_train, weights_);
            }
            if (convergence && !finished) {
                // Convergence check on the held-out validation split.
                auto y_val_predict = predict(X_test);
                double accuracy = (y_val_predict == y_test).sum().item<double>() / (double)y_test.size(0);
                if (priorAccuracy == 0) {
                    priorAccuracy = accuracy;
                } else {
                    improvement = accuracy - priorAccuracy;
                }
                if (improvement < convergence_threshold) {
                    // VLOG_SCOPE_F(3, " (improvement<threshold) tolerance: %d numItemsPack: %d improvement: %f prior: %f current: %f", tolerance, numItemsPack, improvement, priorAccuracy, accuracy);
                    tolerance++;
                } else {
                    // VLOG_SCOPE_F(3, "* (improvement>=threshold) Reset. tolerance: %d numItemsPack: %d improvement: %f prior: %f current: %f", tolerance, numItemsPack, improvement, priorAccuracy, accuracy);
                    tolerance = 0; // Reset the counter if the model performs better
                    numItemsPack = 0;
                }
                if (convergence_best) {
                    // Keep the best accuracy until now as the prior accuracy
                    priorAccuracy = std::max(accuracy, priorAccuracy);
                } else {
                    // Keep the last accuracy obtained as the prior accuracy
                    priorAccuracy = accuracy;
                }
            }
            // VLOG_SCOPE_F(1, "tolerance: %d featuresUsed.size: %zu features.size: %zu", tolerance, featuresUsed.size(), features.size());
            finished = finished || tolerance > maxTolerance || pairSelection.size() == 0;
        }
        if (tolerance > maxTolerance) {
            // Convergence stop: roll back the models added since the last
            // round that actually improved validation accuracy.
            if (numItemsPack < n_models) {
                notes.push_back("Convergence threshold reached & " + std::to_string(numItemsPack) + " models eliminated");
                // VLOG_SCOPE_F(4, "Convergence threshold reached & %d models eliminated of %d", numItemsPack, n_models);
                for (int i = 0; i < numItemsPack; ++i) {
                    significanceModels.pop_back();
                    models.pop_back();
                    n_models--;
                }
            } else {
                notes.push_back("Convergence threshold reached & 0 models eliminated");
                // VLOG_SCOPE_F(4, "Convergence threshold reached & 0 models eliminated n_models=%d numItemsPack=%d", n_models, numItemsPack);
            }
        }
        if (pairSelection.size() > 0) {
            notes.push_back("Pairs not used in train: " + std::to_string(pairSelection.size()));
            status = WARNING;
        }
        notes.push_back("Number of models: " + std::to_string(n_models));
    }
    // Render the ensemble as a graph description; delegates entirely to the
    // Ensemble base class.
    std::vector<std::string> BoostA2DE::graph(const std::string& title) const
    {
        return Ensemble::graph(title);
    }
167}
</html>