2024-04-11 16:02:49 +00:00
// ***************************************************************
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
// SPDX-FileType: SOURCE
// SPDX-License-Identifier: MIT
// ***************************************************************
2024-05-15 18:00:44 +00:00
# include <random>
2023-08-20 18:31:23 +00:00
# include <set>
2023-10-10 16:16:43 +00:00
# include <functional>
# include <limits.h>
2024-03-05 11:10:58 +00:00
# include <tuple>
2023-10-10 16:16:43 +00:00
# include "BoostAODE.h"
2023-08-15 14:16:04 +00:00
namespace bayesnet {
2024-04-02 07:52:40 +00:00
2024-05-15 17:49:15 +00:00
/**
 * Construct a BoostAODE ensemble.
 * @param predict_voting handed over unchanged to the Boost base-class constructor;
 *                       this class adds no state of its own here.
 */
BoostAODE::BoostAODE(bool predict_voting) : Boost(predict_voting) {}
2024-03-20 22:33:02 +00:00
std : : vector < int > BoostAODE : : initializeModels ( )
2023-10-10 09:52:39 +00:00
{
2023-11-08 17:45:35 +00:00
torch : : Tensor weights_ = torch : : full ( { m } , 1.0 / m , torch : : kFloat64 ) ;
2024-05-15 17:49:15 +00:00
std : : vector < int > featuresSelected = featureSelection ( weights_ ) ;
for ( const int & feature : featuresSelected ) {
2023-11-08 17:45:35 +00:00
std : : unique_ptr < Classifier > model = std : : make_unique < SPODE > ( feature ) ;
2023-10-13 11:46:22 +00:00
model - > fit ( dataset , features , className , states , weights_ ) ;
models . push_back ( std : : move ( model ) ) ;
2024-04-09 22:55:36 +00:00
significanceModels . push_back ( 1.0 ) ; // They will be updated later in trainModel
2023-10-13 11:46:22 +00:00
n_models + + ;
2023-10-10 09:52:39 +00:00
}
2024-05-15 17:49:15 +00:00
notes . push_back ( " Used features in initialization: " + std : : to_string ( featuresSelected . size ( ) ) + " of " + std : : to_string ( features . size ( ) ) + " with " + select_features_algorithm ) ;
return featuresSelected ;
2023-10-10 09:52:39 +00:00
}
2023-08-16 17:05:18 +00:00
void BoostAODE : : trainModel ( const torch : : Tensor & weights )
{
2024-04-29 22:52:09 +00:00
//
// Logging setup
//
2024-05-16 09:17:21 +00:00
// loguru::set_thread_name("BoostAODE");
// loguru::g_stderr_verbosity = loguru::Verbosity_OFF;
// loguru::add_file("boostAODE.log", loguru::Truncate, loguru::Verbosity_MAX);
2024-04-29 22:52:09 +00:00
2024-03-06 16:04:16 +00:00
// Algorithm based on the adaboost algorithm for classification
// as explained in Ensemble methods (Zhi-Hua Zhou, 2012)
2024-02-20 09:11:22 +00:00
fitted = true ;
2024-03-05 11:10:58 +00:00
double alpha_t = 0 ;
torch : : Tensor weights_ = torch : : full ( { m } , 1.0 / m , torch : : kFloat64 ) ;
2024-03-19 13:13:40 +00:00
bool finished = false ;
2024-03-20 22:33:02 +00:00
std : : vector < int > featuresUsed ;
2023-10-14 11:12:04 +00:00
if ( selectFeatures ) {
2023-10-11 09:33:29 +00:00
featuresUsed = initializeModels ( ) ;
2024-03-05 11:10:58 +00:00
auto ypred = predict ( X_train ) ;
2024-03-19 13:13:40 +00:00
std : : tie ( weights_ , alpha_t , finished ) = update_weights ( y_train , ypred , weights_ ) ;
2024-03-05 11:10:58 +00:00
// Update significance of the models
for ( int i = 0 ; i < n_models ; + + i ) {
significanceModels [ i ] = alpha_t ;
}
2024-03-19 13:13:40 +00:00
if ( finished ) {
2024-03-05 11:10:58 +00:00
return ;
}
2023-10-11 09:33:29 +00:00
}
2024-03-20 10:30:02 +00:00
int numItemsPack = 0 ; // The counter of the models inserted in the current pack
2023-09-07 09:27:35 +00:00
// Variables to control the accuracy finish condition
double priorAccuracy = 0.0 ;
2024-03-20 22:33:02 +00:00
double improvement = 1.0 ;
2024-02-19 21:58:15 +00:00
double convergence_threshold = 1e-4 ;
2024-03-19 13:13:40 +00:00
int tolerance = 0 ; // number of times the accuracy is lower than the convergence_threshold
2023-08-18 09:50:34 +00:00
// Step 0: Set the finish condition
2023-10-25 08:23:42 +00:00
// epsilon sub t > 0.5 => inverse the weights policy
2023-09-07 09:27:35 +00:00
// validation error is not decreasing
2024-03-20 10:30:02 +00:00
// run out of features
2024-03-05 10:05:11 +00:00
bool ascending = order_algorithm = = Orders . ASC ;
2024-02-26 16:07:57 +00:00
std : : mt19937 g { 173 } ;
2024-03-19 13:13:40 +00:00
while ( ! finished ) {
2023-08-18 09:50:34 +00:00
// Step 1: Build ranking with mutual information
2023-08-20 18:31:23 +00:00
auto featureSelection = metrics . SelectKBestWeighted ( weights_ , ascending , n ) ; // Get all the features sorted
2024-03-05 10:05:11 +00:00
if ( order_algorithm = = Orders . RAND ) {
2024-02-26 16:07:57 +00:00
std : : shuffle ( featureSelection . begin ( ) , featureSelection . end ( ) , g ) ;
}
2024-03-19 08:42:03 +00:00
// Remove used features
featureSelection . erase ( remove_if ( begin ( featureSelection ) , end ( featureSelection ) , [ & ] ( auto x )
2024-03-20 10:30:02 +00:00
{ return std : : find ( begin ( featuresUsed ) , end ( featuresUsed ) , x ) ! = end ( featuresUsed ) ; } ) ,
2024-03-19 08:42:03 +00:00
end ( featureSelection )
) ;
2024-04-29 22:52:09 +00:00
int k = bisection ? pow ( 2 , tolerance ) : 1 ;
2024-03-20 10:30:02 +00:00
int counter = 0 ; // The model counter of the current pack
2024-05-16 09:17:21 +00:00
// VLOG_SCOPE_F(1, "counter=%d k=%d featureSelection.size: %zu", counter, k, featureSelection.size());
2024-03-20 10:30:02 +00:00
while ( counter + + < k & & featureSelection . size ( ) > 0 ) {
2024-03-19 13:13:40 +00:00
auto feature = featureSelection [ 0 ] ;
featureSelection . erase ( featureSelection . begin ( ) ) ;
std : : unique_ptr < Classifier > model ;
model = std : : make_unique < SPODE > ( feature ) ;
model - > fit ( dataset , features , className , states , weights_ ) ;
2024-04-09 22:55:36 +00:00
alpha_t = 0.0 ;
if ( ! block_update ) {
auto ypred = model - > predict ( X_train ) ;
// Step 3.1: Compute the classifier amout of say
std : : tie ( weights_ , alpha_t , finished ) = update_weights ( y_train , ypred , weights_ ) ;
2024-03-19 13:13:40 +00:00
}
// Step 3.4: Store classifier and its accuracy to weigh its future vote
2024-03-20 10:30:02 +00:00
numItemsPack + + ;
2024-03-20 22:33:02 +00:00
featuresUsed . push_back ( feature ) ;
2024-03-19 13:13:40 +00:00
models . push_back ( std : : move ( model ) ) ;
significanceModels . push_back ( alpha_t ) ;
n_models + + ;
2024-05-16 09:17:21 +00:00
// VLOG_SCOPE_F(2, "numItemsPack: %d n_models: %d featuresUsed: %zu", numItemsPack, n_models, featuresUsed.size());
2024-02-20 09:11:22 +00:00
}
2024-04-09 22:55:36 +00:00
if ( block_update ) {
std : : tie ( weights_ , alpha_t , finished ) = update_weights_block ( k , y_train , weights_ ) ;
}
2024-03-20 10:30:02 +00:00
if ( convergence & & ! finished ) {
2023-09-10 17:50:36 +00:00
auto y_val_predict = predict ( X_test ) ;
double accuracy = ( y_val_predict = = y_test ) . sum ( ) . item < double > ( ) / ( double ) y_test . size ( 0 ) ;
if ( priorAccuracy = = 0 ) {
priorAccuracy = accuracy ;
} else {
2024-03-20 22:33:02 +00:00
improvement = accuracy - priorAccuracy ;
2023-09-10 17:50:36 +00:00
}
2024-03-20 22:33:02 +00:00
if ( improvement < convergence_threshold ) {
2024-05-16 09:17:21 +00:00
// VLOG_SCOPE_F(3, " (improvement<threshold) tolerance: %d numItemsPack: %d improvement: %f prior: %f current: %f", tolerance, numItemsPack, improvement, priorAccuracy, accuracy);
2024-03-19 13:13:40 +00:00
tolerance + + ;
2024-03-11 20:30:01 +00:00
} else {
2024-05-16 09:17:21 +00:00
// VLOG_SCOPE_F(3, "* (improvement>=threshold) Reset. tolerance: %d numItemsPack: %d improvement: %f prior: %f current: %f", tolerance, numItemsPack, improvement, priorAccuracy, accuracy);
2024-03-19 13:13:40 +00:00
tolerance = 0 ; // Reset the counter if the model performs better
2024-03-20 10:30:02 +00:00
numItemsPack = 0 ;
2023-09-10 17:50:36 +00:00
}
2024-04-29 22:52:09 +00:00
if ( convergence_best ) {
// Keep the best accuracy until now as the prior accuracy
priorAccuracy = std : : max ( accuracy , priorAccuracy ) ;
} else {
// Keep the last accuray obtained as the prior accuracy
priorAccuracy = accuracy ;
}
2023-09-07 09:27:35 +00:00
}
2024-05-16 09:17:21 +00:00
// VLOG_SCOPE_F(1, "tolerance: %d featuresUsed.size: %zu features.size: %zu", tolerance, featuresUsed.size(), features.size());
2024-03-20 10:30:02 +00:00
finished = finished | | tolerance > maxTolerance | | featuresUsed . size ( ) = = features . size ( ) ;
2023-08-20 18:31:23 +00:00
}
2024-03-20 10:30:02 +00:00
if ( tolerance > maxTolerance ) {
if ( numItemsPack < n_models ) {
notes . push_back ( " Convergence threshold reached & " + std : : to_string ( numItemsPack ) + " models eliminated " ) ;
2024-05-16 09:17:21 +00:00
// VLOG_SCOPE_F(4, "Convergence threshold reached & %d models eliminated of %d", numItemsPack, n_models);
2024-03-20 10:30:02 +00:00
for ( int i = 0 ; i < numItemsPack ; + + i ) {
significanceModels . pop_back ( ) ;
models . pop_back ( ) ;
n_models - - ;
}
} else {
notes . push_back ( " Convergence threshold reached & 0 models eliminated " ) ;
2024-05-16 09:17:21 +00:00
// VLOG_SCOPE_F(4, "Convergence threshold reached & 0 models eliminated n_models=%d numItemsPack=%d", n_models, numItemsPack);
2024-03-19 13:13:40 +00:00
}
2024-03-11 21:33:50 +00:00
}
2023-08-20 18:31:23 +00:00
if ( featuresUsed . size ( ) ! = features . size ( ) ) {
2024-02-08 17:01:09 +00:00
notes . push_back ( " Used features in train: " + std : : to_string ( featuresUsed . size ( ) ) + " of " + std : : to_string ( features . size ( ) ) ) ;
2023-09-05 11:39:43 +00:00
status = WARNING ;
2023-08-16 17:05:18 +00:00
}
2024-02-12 09:58:20 +00:00
notes . push_back ( " Number of models: " + std : : to_string ( n_models ) ) ;
2023-08-16 17:05:18 +00:00
}
2023-11-08 17:45:35 +00:00
std : : vector < std : : string > BoostAODE : : graph ( const std : : string & title ) const
2023-08-15 14:16:04 +00:00
{
return Ensemble : : graph ( title ) ;
}
}