2024-04-21 14:44:35 +00:00
<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.01 Transitional//EN">
< html lang = "en" >
< head >
< meta http-equiv = "Content-Type" content = "text/html; charset=UTF-8" >
< title > LCOV - coverage.info - bayesnet/ensembles/BoostAODE.cc< / title >
< link rel = "stylesheet" type = "text/css" href = "../../gcov.css" >
< / head >
< body >
< table width = "100%" border = 0 cellspacing = 0 cellpadding = 0 >
< tr > < td class = "title" > LCOV - code coverage report< / td > < / tr >
< tr > < td class = "ruler" > < img src = "../../glass.png" width = 3 height = 3 alt = "" > < / td > < / tr >
< tr >
< td width = "100%" >
< table cellpadding = 1 border = 0 width = "100%" >
< tr >
< td width = "10%" class = "headerItem" > Current view:< / td >
< td width = "10%" class = "headerValue" > < a href = "../../index.html" > top level< / a > - < a href = "index.html" > bayesnet/ensembles< / a > - BoostAODE.cc< span style = "font-size: 80%;" > (source / < a href = "BoostAODE.cc.func-c.html" > functions< / a > )< / span > < / td >
< td width = "5%" > < / td >
< td width = "5%" > < / td >
< td width = "5%" class = "headerCovTableHead" > Coverage< / td >
< td width = "5%" class = "headerCovTableHead" title = "Covered + Uncovered code" > Total< / td >
< td width = "5%" class = "headerCovTableHead" title = "Exercised code only" > Hit< / td >
< / tr >
< tr >
< td class = "headerItem" > Test:< / td >
< td class = "headerValue" > coverage.info< / td >
< td > < / td >
< td class = "headerItem" > Lines:< / td >
2024-04-29 22:52:09 +00:00
< td class = "headerCovTableEntryHi" > 98.3 %< / td >
< td class = "headerCovTableEntry" > 237< / td >
< td class = "headerCovTableEntry" > 233< / td >
2024-04-21 14:44:35 +00:00
< / tr >
< tr >
< td class = "headerItem" > Test Date:< / td >
2024-04-30 12:00:24 +00:00
< td class = "headerValue" > 2024-04-30 13:59:18< / td >
2024-04-21 14:44:35 +00:00
< td > < / td >
< td class = "headerItem" > Functions:< / td >
< td class = "headerCovTableEntryHi" > 100.0 %< / td >
< td class = "headerCovTableEntry" > 9< / td >
< td class = "headerCovTableEntry" > 9< / td >
< / tr >
< tr > < td > < img src = "../../glass.png" width = 3 height = 3 alt = "" > < / td > < / tr >
< / table >
< / td >
< / tr >
< tr > < td class = "ruler" > < img src = "../../glass.png" width = 3 height = 3 alt = "" > < / td > < / tr >
< / table >
< table cellpadding = 0 cellspacing = 0 border = 0 >
< tr >
< td > < br > < / td >
< / tr >
< tr >
< td >
< pre class = "sourceHeading" > Line data Source code< / pre >
< pre class = "source" >
< span id = "L1" > < span class = "lineNum" > 1< / span > : // ***************************************************************< / span >
< span id = "L2" > < span class = "lineNum" > 2< / span > : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez< / span >
< span id = "L3" > < span class = "lineNum" > 3< / span > : // SPDX-FileType: SOURCE< / span >
< span id = "L4" > < span class = "lineNum" > 4< / span > : // SPDX-License-Identifier: MIT< / span >
< span id = "L5" > < span class = "lineNum" > 5< / span > : // ***************************************************************< / span >
< span id = "L6" > < span class = "lineNum" > 6< / span > : < / span >
< span id = "L7" > < span class = "lineNum" > 7< / span > : #include < set> < / span >
< span id = "L8" > < span class = "lineNum" > 8< / span > : #include < functional> < / span >
< span id = "L9" > < span class = "lineNum" > 9< / span > : #include < limits.h> < / span >
< span id = "L10" > < span class = "lineNum" > 10< / span > : #include < tuple> < / span >
< span id = "L11" > < span class = "lineNum" > 11< / span > : #include < folding.hpp> < / span >
< span id = "L12" > < span class = "lineNum" > 12< / span > : #include " bayesnet/feature_selection/CFS.h" < / span >
< span id = "L13" > < span class = "lineNum" > 13< / span > : #include " bayesnet/feature_selection/FCBF.h" < / span >
< span id = "L14" > < span class = "lineNum" > 14< / span > : #include " bayesnet/feature_selection/IWSS.h" < / span >
< span id = "L15" > < span class = "lineNum" > 15< / span > : #include " BoostAODE.h" < / span >
2024-04-29 22:52:09 +00:00
< span id = "L16" > < span class = "lineNum" > 16< / span > : #include " lib/log/loguru.cpp" < / span >
< span id = "L17" > < span class = "lineNum" > 17< / span > : < / span >
< span id = "L18" > < span class = "lineNum" > 18< / span > : namespace bayesnet {< / span >
< span id = "L19" > < span class = "lineNum" > 19< / span > : < / span >
2024-04-30 12:00:24 +00:00
< span id = "L20" > < span class = "lineNum" > 20< / span > < span class = "tlaGNC tlaBgGNC" > 252 : BoostAODE::BoostAODE(bool predict_voting) : Ensemble(predict_voting)< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L21" > < span class = "lineNum" > 21< / span > : {< / span >
2024-04-30 12:00:24 +00:00
< span id = "L22" > < span class = "lineNum" > 22< / span > < span class = "tlaGNC" > 2772 : validHyperparameters = {< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L23" > < span class = "lineNum" > 23< / span > : " maxModels" , " bisection" , " order" , " convergence" , " convergence_best" , " threshold" ,< / span >
< span id = "L24" > < span class = "lineNum" > 24< / span > : " select_features" , " maxTolerance" , " predict_voting" , " block_update" < / span >
2024-04-30 12:00:24 +00:00
< span id = "L25" > < span class = "lineNum" > 25< / span > < span class = "tlaGNC" > 2772 : };< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L26" > < span class = "lineNum" > 26< / span > : < / span >
2024-04-30 12:00:24 +00:00
< span id = "L27" > < span class = "lineNum" > 27< / span > < span class = "tlaGNC" > 756 : }< / span > < / span >
< span id = "L28" > < span class = "lineNum" > 28< / span > < span class = "tlaGNC" > 138 : void BoostAODE::buildModel(const torch::Tensor& weights)< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L29" > < span class = "lineNum" > 29< / span > : {< / span >
< span id = "L30" > < span class = "lineNum" > 30< / span > : // Models shall be built in trainModel< / span >
2024-04-30 12:00:24 +00:00
< span id = "L31" > < span class = "lineNum" > 31< / span > < span class = "tlaGNC" > 138 : models.clear();< / span > < / span >
< span id = "L32" > < span class = "lineNum" > 32< / span > < span class = "tlaGNC" > 138 : significanceModels.clear();< / span > < / span >
< span id = "L33" > < span class = "lineNum" > 33< / span > < span class = "tlaGNC" > 138 : n_models = 0;< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L34" > < span class = "lineNum" > 34< / span > : // Prepare the validation dataset< / span >
2024-04-30 12:00:24 +00:00
< span id = "L35" > < span class = "lineNum" > 35< / span > < span class = "tlaGNC" > 414 : auto y_ = dataset.index({ -1, " ..." });< / span > < / span >
< span id = "L36" > < span class = "lineNum" > 36< / span > < span class = "tlaGNC" > 138 : if (convergence) {< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L37" > < span class = "lineNum" > 37< / span > : // Prepare train & validation sets from train data< / span >
2024-04-30 12:00:24 +00:00
< span id = "L38" > < span class = "lineNum" > 38< / span > < span class = "tlaGNC" > 114 : auto fold = folding::StratifiedKFold(5, y_, 271);< / span > < / span >
< span id = "L39" > < span class = "lineNum" > 39< / span > < span class = "tlaGNC" > 114 : auto [train, test] = fold.getFold(0);< / span > < / span >
< span id = "L40" > < span class = "lineNum" > 40< / span > < span class = "tlaGNC" > 114 : auto train_t = torch::tensor(train);< / span > < / span >
< span id = "L41" > < span class = "lineNum" > 41< / span > < span class = "tlaGNC" > 114 : auto test_t = torch::tensor(test);< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L42" > < span class = "lineNum" > 42< / span > : // Get train and validation sets< / span >
2024-04-30 12:00:24 +00:00
< span id = "L43" > < span class = "lineNum" > 43< / span > < span class = "tlaGNC" > 570 : X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), train_t });< / span > < / span >
< span id = "L44" > < span class = "lineNum" > 44< / span > < span class = "tlaGNC" > 342 : y_train = dataset.index({ -1, train_t });< / span > < / span >
< span id = "L45" > < span class = "lineNum" > 45< / span > < span class = "tlaGNC" > 570 : X_test = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), test_t });< / span > < / span >
< span id = "L46" > < span class = "lineNum" > 46< / span > < span class = "tlaGNC" > 342 : y_test = dataset.index({ -1, test_t });< / span > < / span >
< span id = "L47" > < span class = "lineNum" > 47< / span > < span class = "tlaGNC" > 114 : dataset = X_train;< / span > < / span >
< span id = "L48" > < span class = "lineNum" > 48< / span > < span class = "tlaGNC" > 114 : m = X_train.size(1);< / span > < / span >
< span id = "L49" > < span class = "lineNum" > 49< / span > < span class = "tlaGNC" > 114 : auto n_classes = states.at(className).size();< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L50" > < span class = "lineNum" > 50< / span > : // Build dataset with train data< / span >
2024-04-30 12:00:24 +00:00
< span id = "L51" > < span class = "lineNum" > 51< / span > < span class = "tlaGNC" > 114 : buildDataset(y_train);< / span > < / span >
< span id = "L52" > < span class = "lineNum" > 52< / span > < span class = "tlaGNC" > 114 : metrics = Metrics(dataset, features, className, n_classes);< / span > < / span >
< span id = "L53" > < span class = "lineNum" > 53< / span > < span class = "tlaGNC" > 114 : } else {< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L54" > < span class = "lineNum" > 54< / span > : // Use all data to train< / span >
2024-04-30 12:00:24 +00:00
< span id = "L55" > < span class = "lineNum" > 55< / span > < span class = "tlaGNC" > 96 : X_train = dataset.index({ torch::indexing::Slice(0, dataset.size(0) - 1), " ..." });< / span > < / span >
< span id = "L56" > < span class = "lineNum" > 56< / span > < span class = "tlaGNC" > 24 : y_train = y_;< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L57" > < span class = "lineNum" > 57< / span > : }< / span >
2024-04-30 12:00:24 +00:00
< span id = "L58" > < span class = "lineNum" > 58< / span > < span class = "tlaGNC" > 1350 : }< / span > < / span >
< span id = "L59" > < span class = "lineNum" > 59< / span > < span class = "tlaGNC" > 132 : void BoostAODE::setHyperparameters(const nlohmann::json& hyperparameters_)< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L60" > < span class = "lineNum" > 60< / span > : {< / span >
2024-04-30 12:00:24 +00:00
< span id = "L61" > < span class = "lineNum" > 61< / span > < span class = "tlaGNC" > 132 : auto hyperparameters = hyperparameters_;< / span > < / span >
< span id = "L62" > < span class = "lineNum" > 62< / span > < span class = "tlaGNC" > 132 : if (hyperparameters.contains(" order" )) {< / span > < / span >
< span id = "L63" > < span class = "lineNum" > 63< / span > < span class = "tlaGNC" > 150 : std::vector< std::string> algos = { Orders.ASC, Orders.DESC, Orders.RAND };< / span > < / span >
< span id = "L64" > < span class = "lineNum" > 64< / span > < span class = "tlaGNC" > 30 : order_algorithm = hyperparameters[" order" ];< / span > < / span >
< span id = "L65" > < span class = "lineNum" > 65< / span > < span class = "tlaGNC" > 30 : if (std::find(algos.begin(), algos.end(), order_algorithm) == algos.end()) {< / span > < / span >
< span id = "L66" > < span class = "lineNum" > 66< / span > < span class = "tlaGNC" > 6 : throw std::invalid_argument(" Invalid order algorithm, valid values [" + Orders.ASC + " , " + Orders.DESC + " , " + Orders.RAND + " ]" );< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L67" > < span class = "lineNum" > 67< / span > : }< / span >
2024-04-30 12:00:24 +00:00
< span id = "L68" > < span class = "lineNum" > 68< / span > < span class = "tlaGNC" > 24 : hyperparameters.erase(" order" );< / span > < / span >
< span id = "L69" > < span class = "lineNum" > 69< / span > < span class = "tlaGNC" > 30 : }< / span > < / span >
< span id = "L70" > < span class = "lineNum" > 70< / span > < span class = "tlaGNC" > 126 : if (hyperparameters.contains(" convergence" )) {< / span > < / span >
< span id = "L71" > < span class = "lineNum" > 71< / span > < span class = "tlaGNC" > 54 : convergence = hyperparameters[" convergence" ];< / span > < / span >
< span id = "L72" > < span class = "lineNum" > 72< / span > < span class = "tlaGNC" > 54 : hyperparameters.erase(" convergence" );< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L73" > < span class = "lineNum" > 73< / span > : }< / span >
2024-04-30 12:00:24 +00:00
< span id = "L74" > < span class = "lineNum" > 74< / span > < span class = "tlaGNC" > 126 : if (hyperparameters.contains(" convergence_best" )) {< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L75" > < span class = "lineNum" > 75< / span > < span class = "tlaGNC" > 18 : convergence_best = hyperparameters[" convergence_best" ];< / span > < / span >
< span id = "L76" > < span class = "lineNum" > 76< / span > < span class = "tlaGNC" > 18 : hyperparameters.erase(" convergence_best" );< / span > < / span >
< span id = "L77" > < span class = "lineNum" > 77< / span > : }< / span >
2024-04-30 12:00:24 +00:00
< span id = "L78" > < span class = "lineNum" > 78< / span > < span class = "tlaGNC" > 126 : if (hyperparameters.contains(" bisection" )) {< / span > < / span >
< span id = "L79" > < span class = "lineNum" > 79< / span > < span class = "tlaGNC" > 48 : bisection = hyperparameters[" bisection" ];< / span > < / span >
< span id = "L80" > < span class = "lineNum" > 80< / span > < span class = "tlaGNC" > 48 : hyperparameters.erase(" bisection" );< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L81" > < span class = "lineNum" > 81< / span > : }< / span >
2024-04-30 12:00:24 +00:00
< span id = "L82" > < span class = "lineNum" > 82< / span > < span class = "tlaGNC" > 126 : if (hyperparameters.contains(" threshold" )) {< / span > < / span >
< span id = "L83" > < span class = "lineNum" > 83< / span > < span class = "tlaGNC" > 36 : threshold = hyperparameters[" threshold" ];< / span > < / span >
< span id = "L84" > < span class = "lineNum" > 84< / span > < span class = "tlaGNC" > 36 : hyperparameters.erase(" threshold" );< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L85" > < span class = "lineNum" > 85< / span > : }< / span >
2024-04-30 12:00:24 +00:00
< span id = "L86" > < span class = "lineNum" > 86< / span > < span class = "tlaGNC" > 126 : if (hyperparameters.contains(" maxTolerance" )) {< / span > < / span >
< span id = "L87" > < span class = "lineNum" > 87< / span > < span class = "tlaGNC" > 66 : maxTolerance = hyperparameters[" maxTolerance" ];< / span > < / span >
< span id = "L88" > < span class = "lineNum" > 88< / span > < span class = "tlaGNC" > 66 : if (maxTolerance < 1 || maxTolerance > 4)< / span > < / span >
< span id = "L89" > < span class = "lineNum" > 89< / span > < span class = "tlaGNC" > 18 : throw std::invalid_argument(" Invalid maxTolerance value, must be greater in [1, 4]" );< / span > < / span >
< span id = "L90" > < span class = "lineNum" > 90< / span > < span class = "tlaGNC" > 48 : hyperparameters.erase(" maxTolerance" );< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L91" > < span class = "lineNum" > 91< / span > : }< / span >
2024-04-30 12:00:24 +00:00
< span id = "L92" > < span class = "lineNum" > 92< / span > < span class = "tlaGNC" > 108 : if (hyperparameters.contains(" predict_voting" )) {< / span > < / span >
< span id = "L93" > < span class = "lineNum" > 93< / span > < span class = "tlaGNC" > 6 : predict_voting = hyperparameters[" predict_voting" ];< / span > < / span >
< span id = "L94" > < span class = "lineNum" > 94< / span > < span class = "tlaGNC" > 6 : hyperparameters.erase(" predict_voting" );< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L95" > < span class = "lineNum" > 95< / span > : }< / span >
2024-04-30 12:00:24 +00:00
< span id = "L96" > < span class = "lineNum" > 96< / span > < span class = "tlaGNC" > 108 : if (hyperparameters.contains(" select_features" )) {< / span > < / span >
< span id = "L97" > < span class = "lineNum" > 97< / span > < span class = "tlaGNC" > 54 : auto selectedAlgorithm = hyperparameters[" select_features" ];< / span > < / span >
< span id = "L98" > < span class = "lineNum" > 98< / span > < span class = "tlaGNC" > 270 : std::vector< std::string> algos = { SelectFeatures.IWSS, SelectFeatures.CFS, SelectFeatures.FCBF };< / span > < / span >
< span id = "L99" > < span class = "lineNum" > 99< / span > < span class = "tlaGNC" > 54 : selectFeatures = true;< / span > < / span >
< span id = "L100" > < span class = "lineNum" > 100< / span > < span class = "tlaGNC" > 54 : select_features_algorithm = selectedAlgorithm;< / span > < / span >
< span id = "L101" > < span class = "lineNum" > 101< / span > < span class = "tlaGNC" > 54 : if (std::find(algos.begin(), algos.end(), selectedAlgorithm) == algos.end()) {< / span > < / span >
< span id = "L102" > < span class = "lineNum" > 102< / span > < span class = "tlaGNC" > 6 : throw std::invalid_argument(" Invalid selectFeatures value, valid values [" + SelectFeatures.IWSS + " , " + SelectFeatures.CFS + " , " + SelectFeatures.FCBF + " ]" );< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L103" > < span class = "lineNum" > 103< / span > : }< / span >
2024-04-30 12:00:24 +00:00
< span id = "L104" > < span class = "lineNum" > 104< / span > < span class = "tlaGNC" > 48 : hyperparameters.erase(" select_features" );< / span > < / span >
< span id = "L105" > < span class = "lineNum" > 105< / span > < span class = "tlaGNC" > 60 : }< / span > < / span >
< span id = "L106" > < span class = "lineNum" > 106< / span > < span class = "tlaGNC" > 102 : if (hyperparameters.contains(" block_update" )) {< / span > < / span >
< span id = "L107" > < span class = "lineNum" > 107< / span > < span class = "tlaGNC" > 12 : block_update = hyperparameters[" block_update" ];< / span > < / span >
< span id = "L108" > < span class = "lineNum" > 108< / span > < span class = "tlaGNC" > 12 : hyperparameters.erase(" block_update" );< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L109" > < span class = "lineNum" > 109< / span > : }< / span >
2024-04-30 12:00:24 +00:00
< span id = "L110" > < span class = "lineNum" > 110< / span > < span class = "tlaGNC" > 102 : Classifier::setHyperparameters(hyperparameters);< / span > < / span >
< span id = "L111" > < span class = "lineNum" > 111< / span > < span class = "tlaGNC" > 216 : }< / span > < / span >
< span id = "L112" > < span class = "lineNum" > 112< / span > < span class = "tlaGNC" > 816 : std::tuple< torch::Tensor& , double, bool> update_weights(torch::Tensor& ytrain, torch::Tensor& ypred, torch::Tensor& weights)< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L113" > < span class = "lineNum" > 113< / span > : {< / span >
2024-04-30 12:00:24 +00:00
< span id = "L114" > < span class = "lineNum" > 114< / span > < span class = "tlaGNC" > 816 : bool terminate = false;< / span > < / span >
< span id = "L115" > < span class = "lineNum" > 115< / span > < span class = "tlaGNC" > 816 : double alpha_t = 0;< / span > < / span >
< span id = "L116" > < span class = "lineNum" > 116< / span > < span class = "tlaGNC" > 816 : auto mask_wrong = ypred != ytrain;< / span > < / span >
< span id = "L117" > < span class = "lineNum" > 117< / span > < span class = "tlaGNC" > 816 : auto mask_right = ypred == ytrain;< / span > < / span >
< span id = "L118" > < span class = "lineNum" > 118< / span > < span class = "tlaGNC" > 816 : auto masked_weights = weights * mask_wrong.to(weights.dtype());< / span > < / span >
< span id = "L119" > < span class = "lineNum" > 119< / span > < span class = "tlaGNC" > 816 : double epsilon_t = masked_weights.sum().item< double> ();< / span > < / span >
< span id = "L120" > < span class = "lineNum" > 120< / span > < span class = "tlaGNC" > 816 : if (epsilon_t > 0.5) {< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L121" > < span class = "lineNum" > 121< / span > : // Inverse the weights policy (plot ln(wt))< / span >
< span id = "L122" > < span class = "lineNum" > 122< / span > : // " In each round of AdaBoost, there is a sanity check to ensure that the current base < / span >
< span id = "L123" > < span class = "lineNum" > 123< / span > : // learner is better than random guess" (Zhi-Hua Zhou, 2012)< / span >
2024-04-30 12:00:24 +00:00
< span id = "L124" > < span class = "lineNum" > 124< / span > < span class = "tlaGNC" > 24 : terminate = true;< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L125" > < span class = "lineNum" > 125< / span > : } else {< / span >
2024-04-30 12:00:24 +00:00
< span id = "L126" > < span class = "lineNum" > 126< / span > < span class = "tlaGNC" > 792 : double wt = (1 - epsilon_t) / epsilon_t;< / span > < / span >
< span id = "L127" > < span class = "lineNum" > 127< / span > < span class = "tlaGNC" > 792 : alpha_t = epsilon_t == 0 ? 1 : 0.5 * log(wt);< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L128" > < span class = "lineNum" > 128< / span > : // Step 3.2: Update weights for next classifier< / span >
< span id = "L129" > < span class = "lineNum" > 129< / span > : // Step 3.2.1: Update weights of wrong samples< / span >
2024-04-30 12:00:24 +00:00
< span id = "L130" > < span class = "lineNum" > 130< / span > < span class = "tlaGNC" > 792 : weights += mask_wrong.to(weights.dtype()) * exp(alpha_t) * weights;< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L131" > < span class = "lineNum" > 131< / span > : // Step 3.2.2: Update weights of right samples< / span >
2024-04-30 12:00:24 +00:00
< span id = "L132" > < span class = "lineNum" > 132< / span > < span class = "tlaGNC" > 792 : weights += mask_right.to(weights.dtype()) * exp(-alpha_t) * weights;< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L133" > < span class = "lineNum" > 133< / span > : // Step 3.3: Normalise the weights< / span >
2024-04-30 12:00:24 +00:00
< span id = "L134" > < span class = "lineNum" > 134< / span > < span class = "tlaGNC" > 792 : double totalWeights = torch::sum(weights).item< double> ();< / span > < / span >
< span id = "L135" > < span class = "lineNum" > 135< / span > < span class = "tlaGNC" > 792 : weights = weights / totalWeights;< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L136" > < span class = "lineNum" > 136< / span > : }< / span >
2024-04-30 12:00:24 +00:00
< span id = "L137" > < span class = "lineNum" > 137< / span > < span class = "tlaGNC" > 1632 : return { weights, alpha_t, terminate };< / span > < / span >
< span id = "L138" > < span class = "lineNum" > 138< / span > < span class = "tlaGNC" > 816 : }< / span > < / span >
< span id = "L139" > < span class = "lineNum" > 139< / span > < span class = "tlaGNC" > 42 : std::tuple< torch::Tensor& , double, bool> BoostAODE::update_weights_block(int k, torch::Tensor& ytrain, torch::Tensor& weights)< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L140" > < span class = "lineNum" > 140< / span > : {< / span >
< span id = "L141" > < span class = "lineNum" > 141< / span > : /* Update Block algorithm< / span >
< span id = "L142" > < span class = "lineNum" > 142< / span > : k = # of models in block< / span >
< span id = "L143" > < span class = "lineNum" > 143< / span > : n_models = # of models in ensemble to make predictions< / span >
< span id = "L144" > < span class = "lineNum" > 144< / span > : n_models_bak = # models saved< / span >
< span id = "L145" > < span class = "lineNum" > 145< / span > : models = vector of models to make predictions< / span >
< span id = "L146" > < span class = "lineNum" > 146< / span > : models_bak = models not used to make predictions< / span >
< span id = "L147" > < span class = "lineNum" > 147< / span > : significances_bak = backup of significances vector< / span >
< span id = "L148" > < span class = "lineNum" > 148< / span > : < / span >
< span id = "L149" > < span class = "lineNum" > 149< / span > : Case list< / span >
< span id = "L150" > < span class = "lineNum" > 150< / span > : A) k = 1, n_models = 1 => n = 0 , n_models = n + k< / span >
< span id = "L151" > < span class = "lineNum" > 151< / span > : B) k = 1, n_models = n + 1 => n_models = n + k< / span >
< span id = "L152" > < span class = "lineNum" > 152< / span > : C) k > 1, n_models = k + 1 => n= 1, n_models = n + k< / span >
< span id = "L153" > < span class = "lineNum" > 153< / span > : D) k > 1, n_models = k => n = 0, n_models = n + k< / span >
< span id = "L154" > < span class = "lineNum" > 154< / span > : E) k > 1, n_models = k + n => n_models = n + k< / span >
< span id = "L155" > < span class = "lineNum" > 155< / span > : < / span >
< span id = "L156" > < span class = "lineNum" > 156< / span > : A, D) n=0, k > 0, n_models == k< / span >
< span id = "L157" > < span class = "lineNum" > 157< / span > : 1. n_models_bak < - n_models< / span >
< span id = "L158" > < span class = "lineNum" > 158< / span > : 2. significances_bak < - significances< / span >
< span id = "L159" > < span class = "lineNum" > 159< / span > : 3. significances = vector(k, 1)< / span >
< span id = "L160" > < span class = "lineNum" > 160< / span > : 4. Don’ t move any classifiers out of models< / span >
< span id = "L161" > < span class = "lineNum" > 161< / span > : 5. n_models < - k< / span >
< span id = "L162" > < span class = "lineNum" > 162< / span > : 6. Make prediction, compute alpha, update weights< / span >
< span id = "L163" > < span class = "lineNum" > 163< / span > : 7. Don’ t restore any classifiers to models< / span >
< span id = "L164" > < span class = "lineNum" > 164< / span > : 8. significances < - significances_bak< / span >
< span id = "L165" > < span class = "lineNum" > 165< / span > : 9. Update last k significances< / span >
< span id = "L166" > < span class = "lineNum" > 166< / span > : 10. n_models < - n_models_bak< / span >
< span id = "L167" > < span class = "lineNum" > 167< / span > : < / span >
< span id = "L168" > < span class = "lineNum" > 168< / span > : B, C, E) n > 0, k > 0, n_models == n + k< / span >
< span id = "L169" > < span class = "lineNum" > 169< / span > : 1. n_models_bak < - n_models< / span >
< span id = "L170" > < span class = "lineNum" > 170< / span > : 2. significances_bak < - significances< / span >
< span id = "L171" > < span class = "lineNum" > 171< / span > : 3. significances = vector(k, 1)< / span >
< span id = "L172" > < span class = "lineNum" > 172< / span > : 4. Move first n classifiers to models_bak< / span >
< span id = "L173" > < span class = "lineNum" > 173< / span > : 5. n_models < - k< / span >
< span id = "L174" > < span class = "lineNum" > 174< / span > : 6. Make prediction, compute alpha, update weights< / span >
< span id = "L175" > < span class = "lineNum" > 175< / span > : 7. Insert classifiers in models_bak to be the first n models< / span >
< span id = "L176" > < span class = "lineNum" > 176< / span > : 8. significances < - significances_bak< / span >
< span id = "L177" > < span class = "lineNum" > 177< / span > : 9. Update last k significances< / span >
< span id = "L178" > < span class = "lineNum" > 178< / span > : 10. n_models < - n_models_bak< / span >
< span id = "L179" > < span class = "lineNum" > 179< / span > : */< / span >
< span id = "L180" > < span class = "lineNum" > 180< / span > : //< / span >
< span id = "L181" > < span class = "lineNum" > 181< / span > : // Make predict with only the last k models< / span >
< span id = "L182" > < span class = "lineNum" > 182< / span > : //< / span >
2024-04-30 12:00:24 +00:00
< span id = "L183" > < span class = "lineNum" > 183< / span > < span class = "tlaGNC" > 42 : std::unique_ptr< Classifier> model;< / span > < / span >
< span id = "L184" > < span class = "lineNum" > 184< / span > < span class = "tlaGNC" > 42 : std::vector< std::unique_ptr< Classifier> > models_bak;< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L185" > < span class = "lineNum" > 185< / span > : // 1. n_models_bak < - n_models 2. significances_bak < - significances< / span >
2024-04-30 12:00:24 +00:00
< span id = "L186" > < span class = "lineNum" > 186< / span > < span class = "tlaGNC" > 42 : auto significance_bak = significanceModels;< / span > < / span >
< span id = "L187" > < span class = "lineNum" > 187< / span > < span class = "tlaGNC" > 42 : auto n_models_bak = n_models;< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L188" > < span class = "lineNum" > 188< / span > : // 3. significances = vector(k, 1)< / span >
2024-04-30 12:00:24 +00:00
< span id = "L189" > < span class = "lineNum" > 189< / span > < span class = "tlaGNC" > 42 : significanceModels = std::vector< double> (k, 1.0);< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L190" > < span class = "lineNum" > 190< / span > : // 4. Move first n classifiers to models_bak< / span >
< span id = "L191" > < span class = "lineNum" > 191< / span > : // backup the first n_models - k models (if n_models == k, don't backup any)< / span >
2024-04-30 12:00:24 +00:00
< span id = "L192" > < span class = "lineNum" > 192< / span > < span class = "tlaGNC" > 222 : for (int i = 0; i < n_models - k; ++i) {< / span > < / span >
< span id = "L193" > < span class = "lineNum" > 193< / span > < span class = "tlaGNC" > 180 : model = std::move(models[0]);< / span > < / span >
< span id = "L194" > < span class = "lineNum" > 194< / span > < span class = "tlaGNC" > 180 : models.erase(models.begin());< / span > < / span >
< span id = "L195" > < span class = "lineNum" > 195< / span > < span class = "tlaGNC" > 180 : models_bak.push_back(std::move(model));< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L196" > < span class = "lineNum" > 196< / span > : }< / span >
2024-04-30 12:00:24 +00:00
< span id = "L197" > < span class = "lineNum" > 197< / span > < span class = "tlaGNC" > 42 : assert(models.size() == k);< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L198" > < span class = "lineNum" > 198< / span > : // 5. n_models < - k< / span >
2024-04-30 12:00:24 +00:00
< span id = "L199" > < span class = "lineNum" > 199< / span > < span class = "tlaGNC" > 42 : n_models = k;< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L200" > < span class = "lineNum" > 200< / span > : // 6. Make prediction, compute alpha, update weights< / span >
2024-04-30 12:00:24 +00:00
< span id = "L201" > < span class = "lineNum" > 201< / span > < span class = "tlaGNC" > 42 : auto ypred = predict(X_train);< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L202" > < span class = "lineNum" > 202< / span > : //< / span >
< span id = "L203" > < span class = "lineNum" > 203< / span > : // Update weights< / span >
< span id = "L204" > < span class = "lineNum" > 204< / span > : //< / span >
< span id = "L205" > < span class = "lineNum" > 205< / span > : double alpha_t;< / span >
< span id = "L206" > < span class = "lineNum" > 206< / span > : bool terminate;< / span >
2024-04-30 12:00:24 +00:00
< span id = "L207" > < span class = "lineNum" > 207< / span > < span class = "tlaGNC" > 42 : std::tie(weights, alpha_t, terminate) = update_weights(y_train, ypred, weights);< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L208" > < span class = "lineNum" > 208< / span > : //< / span >
< span id = "L209" > < span class = "lineNum" > 209< / span > : // Restore the models if needed< / span >
< span id = "L210" > < span class = "lineNum" > 210< / span > : //< / span >
< span id = "L211" > < span class = "lineNum" > 211< / span > : // 7. Insert classifiers in models_bak to be the first n models< / span >
< span id = "L212" > < span class = "lineNum" > 212< / span > : // if n_models_bak == k, don't restore any, because none of them were moved< / span >
2024-04-30 12:00:24 +00:00
< span id = "L213" > < span class = "lineNum" > 213< / span > < span class = "tlaGNC" > 42 : if (k != n_models_bak) {< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L214" > < span class = "lineNum" > 214< / span > : // Insert in the same order as they were extracted< / span >
2024-04-30 12:00:24 +00:00
< span id = "L215" > < span class = "lineNum" > 215< / span > < span class = "tlaGNC" > 36 : int bak_size = models_bak.size();< / span > < / span >
< span id = "L216" > < span class = "lineNum" > 216< / span > < span class = "tlaGNC" > 216 : for (int i = 0; i < bak_size; ++i) {< / span > < / span >
< span id = "L217" > < span class = "lineNum" > 217< / span > < span class = "tlaGNC" > 180 : model = std::move(models_bak[bak_size - 1 - i]);< / span > < / span >
< span id = "L218" > < span class = "lineNum" > 218< / span > < span class = "tlaGNC" > 180 : models_bak.erase(models_bak.end() - 1);< / span > < / span >
< span id = "L219" > < span class = "lineNum" > 219< / span > < span class = "tlaGNC" > 180 : models.insert(models.begin(), std::move(model));< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L220" > < span class = "lineNum" > 220< / span > : }< / span >
< span id = "L221" > < span class = "lineNum" > 221< / span > : }< / span >
< span id = "L222" > < span class = "lineNum" > 222< / span > : // 8. significances < - significances_bak< / span >
2024-04-30 12:00:24 +00:00
< span id = "L223" > < span class = "lineNum" > 223< / span > < span class = "tlaGNC" > 42 : significanceModels = significance_bak;< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L224" > < span class = "lineNum" > 224< / span > : //< / span >
< span id = "L225" > < span class = "lineNum" > 225< / span > : // Update the significance of the last k models< / span >
< span id = "L226" > < span class = "lineNum" > 226< / span > : //< / span >
< span id = "L227" > < span class = "lineNum" > 227< / span > : // 9. Update last k significances< / span >
2024-04-30 12:00:24 +00:00
< span id = "L228" > < span class = "lineNum" > 228< / span > < span class = "tlaGNC" > 156 : for (int i = 0; i < k; ++i) {< / span > < / span >
< span id = "L229" > < span class = "lineNum" > 229< / span > < span class = "tlaGNC" > 114 : significanceModels[n_models_bak - k + i] = alpha_t;< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L230" > < span class = "lineNum" > 230< / span > : }< / span >
< span id = "L231" > < span class = "lineNum" > 231< / span > : // 10. n_models < - n_models_bak< / span >
2024-04-30 12:00:24 +00:00
< span id = "L232" > < span class = "lineNum" > 232< / span > < span class = "tlaGNC" > 42 : n_models = n_models_bak;< / span > < / span >
< span id = "L233" > < span class = "lineNum" > 233< / span > < span class = "tlaGNC" > 84 : return { weights, alpha_t, terminate };< / span > < / span >
< span id = "L234" > < span class = "lineNum" > 234< / span > < span class = "tlaGNC" > 42 : }< / span > < / span >
< span id = "L235" > < span class = "lineNum" > 235< / span > < span class = "tlaGNC" > 48 : std::vector< int> BoostAODE::initializeModels()< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L236" > < span class = "lineNum" > 236< / span > : {< / span >
2024-04-30 12:00:24 +00:00
< span id = "L237" > < span class = "lineNum" > 237< / span > < span class = "tlaGNC" > 48 : std::vector< int> featuresUsed;< / span > < / span >
< span id = "L238" > < span class = "lineNum" > 238< / span > < span class = "tlaGNC" > 48 : torch::Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);< / span > < / span >
< span id = "L239" > < span class = "lineNum" > 239< / span > < span class = "tlaGNC" > 48 : int maxFeatures = 0;< / span > < / span >
< span id = "L240" > < span class = "lineNum" > 240< / span > < span class = "tlaGNC" > 48 : if (select_features_algorithm == SelectFeatures.CFS) {< / span > < / span >
< span id = "L241" > < span class = "lineNum" > 241< / span > < span class = "tlaGNC" > 12 : featureSelector = new CFS(dataset, features, className, maxFeatures, states.at(className).size(), weights_);< / span > < / span >
< span id = "L242" > < span class = "lineNum" > 242< / span > < span class = "tlaGNC" > 36 : } else if (select_features_algorithm == SelectFeatures.IWSS) {< / span > < / span >
< span id = "L243" > < span class = "lineNum" > 243< / span > < span class = "tlaGNC" > 18 : if (threshold < 0 || threshold > 0.5) {< / span > < / span >
< span id = "L244" > < span class = "lineNum" > 244< / span > < span class = "tlaGNC" > 12 : throw std::invalid_argument(" Invalid threshold value for " + SelectFeatures.IWSS + " [0, 0.5]" );< / span > < / span >
2024-04-21 14:44:35 +00:00
< span id = "L245" > < span class = "lineNum" > 245< / span > : }< / span >
2024-04-30 12:00:24 +00:00
< span id = "L246" > < span class = "lineNum" > 246< / span > < span class = "tlaGNC" > 6 : featureSelector = new IWSS(dataset, features, className, maxFeatures, states.at(className).size(), weights_, threshold);< / span > < / span >
< span id = "L247" > < span class = "lineNum" > 247< / span > < span class = "tlaGNC" > 18 : } else if (select_features_algorithm == SelectFeatures.FCBF) {< / span > < / span >
< span id = "L248" > < span class = "lineNum" > 248< / span > < span class = "tlaGNC" > 18 : if (threshold < 1e-7 || threshold > 1) {< / span > < / span >
< span id = "L249" > < span class = "lineNum" > 249< / span > < span class = "tlaGNC" > 12 : throw std::invalid_argument(" Invalid threshold value for " + SelectFeatures.FCBF + " [1e-7, 1]" );< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L250" > < span class = "lineNum" > 250< / span > : }< / span >
2024-04-30 12:00:24 +00:00
< span id = "L251" > < span class = "lineNum" > 251< / span > < span class = "tlaGNC" > 6 : featureSelector = new FCBF(dataset, features, className, maxFeatures, states.at(className).size(), weights_, threshold);< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L252" > < span class = "lineNum" > 252< / span > : }< / span >
2024-04-30 12:00:24 +00:00
< span id = "L253" > < span class = "lineNum" > 253< / span > < span class = "tlaGNC" > 24 : featureSelector-> fit();< / span > < / span >
< span id = "L254" > < span class = "lineNum" > 254< / span > < span class = "tlaGNC" > 24 : auto cfsFeatures = featureSelector-> getFeatures();< / span > < / span >
< span id = "L255" > < span class = "lineNum" > 255< / span > < span class = "tlaGNC" > 24 : auto scores = featureSelector-> getScores();< / span > < / span >
< span id = "L256" > < span class = "lineNum" > 256< / span > < span class = "tlaGNC" > 150 : for (const int& feature : cfsFeatures) {< / span > < / span >
< span id = "L257" > < span class = "lineNum" > 257< / span > < span class = "tlaGNC" > 126 : featuresUsed.push_back(feature);< / span > < / span >
< span id = "L258" > < span class = "lineNum" > 258< / span > < span class = "tlaGNC" > 126 : std::unique_ptr< Classifier> model = std::make_unique< SPODE> (feature);< / span > < / span >
< span id = "L259" > < span class = "lineNum" > 259< / span > < span class = "tlaGNC" > 126 : model-> fit(dataset, features, className, states, weights_);< / span > < / span >
< span id = "L260" > < span class = "lineNum" > 260< / span > < span class = "tlaGNC" > 126 : models.push_back(std::move(model));< / span > < / span >
< span id = "L261" > < span class = "lineNum" > 261< / span > < span class = "tlaGNC" > 126 : significanceModels.push_back(1.0); // They will be updated later in trainModel< / span > < / span >
< span id = "L262" > < span class = "lineNum" > 262< / span > < span class = "tlaGNC" > 126 : n_models++;< / span > < / span >
< span id = "L263" > < span class = "lineNum" > 263< / span > < span class = "tlaGNC" > 126 : }< / span > < / span >
< span id = "L264" > < span class = "lineNum" > 264< / span > < span class = "tlaGNC" > 24 : notes.push_back(" Used features in initialization: " + std::to_string(featuresUsed.size()) + " of " + std::to_string(features.size()) + " with " + select_features_algorithm);< / span > < / span >
< span id = "L265" > < span class = "lineNum" > 265< / span > < span class = "tlaGNC" > 24 : delete featureSelector;< / span > < / span >
< span id = "L266" > < span class = "lineNum" > 266< / span > < span class = "tlaGNC" > 48 : return featuresUsed;< / span > < / span >
< span id = "L267" > < span class = "lineNum" > 267< / span > < span class = "tlaGNC" > 72 : }< / span > < / span >
< span id = "L268" > < span class = "lineNum" > 268< / span > < span class = "tlaGNC" > 138 : void BoostAODE::trainModel(const torch::Tensor& weights)< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L269" > < span class = "lineNum" > 269< / span > : {< / span >
< span id = "L270" > < span class = "lineNum" > 270< / span > : //< / span >
< span id = "L271" > < span class = "lineNum" > 271< / span > : // Logging setup< / span >
< span id = "L272" > < span class = "lineNum" > 272< / span > : //< / span >
2024-04-30 12:00:24 +00:00
< span id = "L273" > < span class = "lineNum" > 273< / span > < span class = "tlaGNC" > 138 : loguru::set_thread_name(" BoostAODE" );< / span > < / span >
< span id = "L274" > < span class = "lineNum" > 274< / span > < span class = "tlaGNC" > 138 : loguru::g_stderr_verbosity = loguru::Verbosity_OFF;< / span > < / span >
< span id = "L275" > < span class = "lineNum" > 275< / span > < span class = "tlaGNC" > 138 : loguru::add_file(" boostAODE.log" , loguru::Truncate, loguru::Verbosity_MAX);< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L276" > < span class = "lineNum" > 276< / span > : < / span >
< span id = "L277" > < span class = "lineNum" > 277< / span > : // Algorithm based on the adaboost algorithm for classification< / span >
< span id = "L278" > < span class = "lineNum" > 278< / span > : // as explained in Ensemble methods (Zhi-Hua Zhou, 2012)< / span >
2024-04-30 12:00:24 +00:00
< span id = "L279" > < span class = "lineNum" > 279< / span > < span class = "tlaGNC" > 138 : fitted = true;< / span > < / span >
< span id = "L280" > < span class = "lineNum" > 280< / span > < span class = "tlaGNC" > 138 : double alpha_t = 0;< / span > < / span >
< span id = "L281" > < span class = "lineNum" > 281< / span > < span class = "tlaGNC" > 138 : torch::Tensor weights_ = torch::full({ m }, 1.0 / m, torch::kFloat64);< / span > < / span >
< span id = "L282" > < span class = "lineNum" > 282< / span > < span class = "tlaGNC" > 138 : bool finished = false;< / span > < / span >
< span id = "L283" > < span class = "lineNum" > 283< / span > < span class = "tlaGNC" > 138 : std::vector< int> featuresUsed;< / span > < / span >
< span id = "L284" > < span class = "lineNum" > 284< / span > < span class = "tlaGNC" > 138 : if (selectFeatures) {< / span > < / span >
< span id = "L285" > < span class = "lineNum" > 285< / span > < span class = "tlaGNC" > 48 : featuresUsed = initializeModels();< / span > < / span >
< span id = "L286" > < span class = "lineNum" > 286< / span > < span class = "tlaGNC" > 24 : auto ypred = predict(X_train);< / span > < / span >
< span id = "L287" > < span class = "lineNum" > 287< / span > < span class = "tlaGNC" > 24 : std::tie(weights_, alpha_t, finished) = update_weights(y_train, ypred, weights_);< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L288" > < span class = "lineNum" > 288< / span > : // Update significance of the models< / span >
2024-04-30 12:00:24 +00:00
< span id = "L289" > < span class = "lineNum" > 289< / span > < span class = "tlaGNC" > 150 : for (int i = 0; i < n_models; ++i) {< / span > < / span >
< span id = "L290" > < span class = "lineNum" > 290< / span > < span class = "tlaGNC" > 126 : significanceModels[i] = alpha_t;< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L291" > < span class = "lineNum" > 291< / span > : }< / span >
2024-04-30 12:00:24 +00:00
< span id = "L292" > < span class = "lineNum" > 292< / span > < span class = "tlaGNC" > 24 : if (finished) {< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L293" > < span class = "lineNum" > 293< / span > < span class = "tlaUNC tlaBgUNC" > 0 : return;< / span > < / span >
< span id = "L294" > < span class = "lineNum" > 294< / span > : }< / span >
2024-04-30 12:00:24 +00:00
< span id = "L295" > < span class = "lineNum" > 295< / span > < span class = "tlaGNC tlaBgGNC" > 24 : }< / span > < / span >
< span id = "L296" > < span class = "lineNum" > 296< / span > < span class = "tlaGNC" > 114 : int numItemsPack = 0; // The counter of the models inserted in the current pack< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L297" > < span class = "lineNum" > 297< / span > : // Variables to control the accuracy finish condition< / span >
2024-04-30 12:00:24 +00:00
< span id = "L298" > < span class = "lineNum" > 298< / span > < span class = "tlaGNC" > 114 : double priorAccuracy = 0.0;< / span > < / span >
< span id = "L299" > < span class = "lineNum" > 299< / span > < span class = "tlaGNC" > 114 : double improvement = 1.0;< / span > < / span >
< span id = "L300" > < span class = "lineNum" > 300< / span > < span class = "tlaGNC" > 114 : double convergence_threshold = 1e-4;< / span > < / span >
< span id = "L301" > < span class = "lineNum" > 301< / span > < span class = "tlaGNC" > 114 : int tolerance = 0; // number of times the accuracy is lower than the convergence_threshold< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L302" > < span class = "lineNum" > 302< / span > : // Step 0: Set the finish condition< / span >
< span id = "L303" > < span class = "lineNum" > 303< / span > : // epsilon sub t > 0.5 => inverse the weights policy< / span >
< span id = "L304" > < span class = "lineNum" > 304< / span > : // validation error is not decreasing< / span >
< span id = "L305" > < span class = "lineNum" > 305< / span > : // run out of features< / span >
2024-04-30 12:00:24 +00:00
< span id = "L306" > < span class = "lineNum" > 306< / span > < span class = "tlaGNC" > 114 : bool ascending = order_algorithm == Orders.ASC;< / span > < / span >
< span id = "L307" > < span class = "lineNum" > 307< / span > < span class = "tlaGNC" > 114 : std::mt19937 g{ 173 };< / span > < / span >
< span id = "L308" > < span class = "lineNum" > 308< / span > < span class = "tlaGNC" > 756 : while (!finished) {< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L309" > < span class = "lineNum" > 309< / span > : // Step 1: Build ranking with mutual information< / span >
2024-04-30 12:00:24 +00:00
< span id = "L310" > < span class = "lineNum" > 310< / span > < span class = "tlaGNC" > 642 : auto featureSelection = metrics.SelectKBestWeighted(weights_, ascending, n); // Get all the features sorted< / span > < / span >
< span id = "L311" > < span class = "lineNum" > 311< / span > < span class = "tlaGNC" > 642 : if (order_algorithm == Orders.RAND) {< / span > < / span >
< span id = "L312" > < span class = "lineNum" > 312< / span > < span class = "tlaGNC" > 54 : std::shuffle(featureSelection.begin(), featureSelection.end(), g);< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L313" > < span class = "lineNum" > 313< / span > : }< / span >
< span id = "L314" > < span class = "lineNum" > 314< / span > : // Remove used features< / span >
2024-04-30 12:00:24 +00:00
< span id = "L315" > < span class = "lineNum" > 315< / span > < span class = "tlaGNC" > 1284 : featureSelection.erase(remove_if(begin(featureSelection), end(featureSelection), [& ](auto x)< / span > < / span >
< span id = "L316" > < span class = "lineNum" > 316< / span > < span class = "tlaGNC" > 58200 : { return std::find(begin(featuresUsed), end(featuresUsed), x) != end(featuresUsed);}),< / span > < / span >
< span id = "L317" > < span class = "lineNum" > 317< / span > < span class = "tlaGNC" > 642 : end(featureSelection)< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L318" > < span class = "lineNum" > 318< / span > : );< / span >
2024-04-30 12:00:24 +00:00
< span id = "L319" > < span class = "lineNum" > 319< / span > < span class = "tlaGNC" > 642 : int k = bisection ? pow(2, tolerance) : 1;< / span > < / span >
< span id = "L320" > < span class = "lineNum" > 320< / span > < span class = "tlaGNC" > 642 : int counter = 0; // The model counter of the current pack< / span > < / span >
< span id = "L321" > < span class = "lineNum" > 321< / span > < span class = "tlaGNC" > 642 : VLOG_SCOPE_F(1, " counter=%d k=%d featureSelection.size: %zu" , counter, k, featureSelection.size());< / span > < / span >
< span id = "L322" > < span class = "lineNum" > 322< / span > < span class = "tlaGNC" > 1506 : while (counter++ < k & & featureSelection.size() > 0) {< / span > < / span >
< span id = "L323" > < span class = "lineNum" > 323< / span > < span class = "tlaGNC" > 864 : auto feature = featureSelection[0];< / span > < / span >
< span id = "L324" > < span class = "lineNum" > 324< / span > < span class = "tlaGNC" > 864 : featureSelection.erase(featureSelection.begin());< / span > < / span >
< span id = "L325" > < span class = "lineNum" > 325< / span > < span class = "tlaGNC" > 864 : std::unique_ptr< Classifier> model;< / span > < / span >
< span id = "L326" > < span class = "lineNum" > 326< / span > < span class = "tlaGNC" > 864 : model = std::make_unique< SPODE> (feature);< / span > < / span >
< span id = "L327" > < span class = "lineNum" > 327< / span > < span class = "tlaGNC" > 864 : model-> fit(dataset, features, className, states, weights_);< / span > < / span >
< span id = "L328" > < span class = "lineNum" > 328< / span > < span class = "tlaGNC" > 864 : alpha_t = 0.0;< / span > < / span >
< span id = "L329" > < span class = "lineNum" > 329< / span > < span class = "tlaGNC" > 864 : if (!block_update) {< / span > < / span >
< span id = "L330" > < span class = "lineNum" > 330< / span > < span class = "tlaGNC" > 750 : auto ypred = model-> predict(X_train);< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L331" > < span class = "lineNum" > 331< / span > : // Step 3.1: Compute the classifier amout of say< / span >
2024-04-30 12:00:24 +00:00
< span id = "L332" > < span class = "lineNum" > 332< / span > < span class = "tlaGNC" > 750 : std::tie(weights_, alpha_t, finished) = update_weights(y_train, ypred, weights_);< / span > < / span >
< span id = "L333" > < span class = "lineNum" > 333< / span > < span class = "tlaGNC" > 750 : }< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L334" > < span class = "lineNum" > 334< / span > : // Step 3.4: Store classifier and its accuracy to weigh its future vote< / span >
2024-04-30 12:00:24 +00:00
< span id = "L335" > < span class = "lineNum" > 335< / span > < span class = "tlaGNC" > 864 : numItemsPack++;< / span > < / span >
< span id = "L336" > < span class = "lineNum" > 336< / span > < span class = "tlaGNC" > 864 : featuresUsed.push_back(feature);< / span > < / span >
< span id = "L337" > < span class = "lineNum" > 337< / span > < span class = "tlaGNC" > 864 : models.push_back(std::move(model));< / span > < / span >
< span id = "L338" > < span class = "lineNum" > 338< / span > < span class = "tlaGNC" > 864 : significanceModels.push_back(alpha_t);< / span > < / span >
< span id = "L339" > < span class = "lineNum" > 339< / span > < span class = "tlaGNC" > 864 : n_models++;< / span > < / span >
< span id = "L340" > < span class = "lineNum" > 340< / span > < span class = "tlaGNC" > 864 : VLOG_SCOPE_F(2, " numItemsPack: %d n_models: %d featuresUsed: %zu" , numItemsPack, n_models, featuresUsed.size());< / span > < / span >
< span id = "L341" > < span class = "lineNum" > 341< / span > < span class = "tlaGNC" > 864 : }< / span > < / span >
< span id = "L342" > < span class = "lineNum" > 342< / span > < span class = "tlaGNC" > 642 : if (block_update) {< / span > < / span >
< span id = "L343" > < span class = "lineNum" > 343< / span > < span class = "tlaGNC" > 42 : std::tie(weights_, alpha_t, finished) = update_weights_block(k, y_train, weights_);< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L344" > < span class = "lineNum" > 344< / span > : }< / span >
2024-04-30 12:00:24 +00:00
< span id = "L345" > < span class = "lineNum" > 345< / span > < span class = "tlaGNC" > 642 : if (convergence & & !finished) {< / span > < / span >
< span id = "L346" > < span class = "lineNum" > 346< / span > < span class = "tlaGNC" > 444 : auto y_val_predict = predict(X_test);< / span > < / span >
< span id = "L347" > < span class = "lineNum" > 347< / span > < span class = "tlaGNC" > 444 : double accuracy = (y_val_predict == y_test).sum().item< double> () / (double)y_test.size(0);< / span > < / span >
< span id = "L348" > < span class = "lineNum" > 348< / span > < span class = "tlaGNC" > 444 : if (priorAccuracy == 0) {< / span > < / span >
< span id = "L349" > < span class = "lineNum" > 349< / span > < span class = "tlaGNC" > 90 : priorAccuracy = accuracy;< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L350" > < span class = "lineNum" > 350< / span > : } else {< / span >
2024-04-30 12:00:24 +00:00
< span id = "L351" > < span class = "lineNum" > 351< / span > < span class = "tlaGNC" > 354 : improvement = accuracy - priorAccuracy;< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L352" > < span class = "lineNum" > 352< / span > : }< / span >
2024-04-30 12:00:24 +00:00
< span id = "L353" > < span class = "lineNum" > 353< / span > < span class = "tlaGNC" > 444 : if (improvement < convergence_threshold) {< / span > < / span >
< span id = "L354" > < span class = "lineNum" > 354< / span > < span class = "tlaGNC" > 264 : VLOG_SCOPE_F(3, " (improvement< threshold) tolerance: %d numItemsPack: %d improvement: %f prior: %f current: %f" , tolerance, numItemsPack, improvement, priorAccuracy, accuracy);< / span > < / span >
< span id = "L355" > < span class = "lineNum" > 355< / span > < span class = "tlaGNC" > 264 : tolerance++;< / span > < / span >
< span id = "L356" > < span class = "lineNum" > 356< / span > < span class = "tlaGNC" > 264 : } else {< / span > < / span >
< span id = "L357" > < span class = "lineNum" > 357< / span > < span class = "tlaGNC" > 180 : VLOG_SCOPE_F(3, " * (improvement> =threshold) Reset. tolerance: %d numItemsPack: %d improvement: %f prior: %f current: %f" , tolerance, numItemsPack, improvement, priorAccuracy, accuracy);< / span > < / span >
< span id = "L358" > < span class = "lineNum" > 358< / span > < span class = "tlaGNC" > 180 : tolerance = 0; // Reset the counter if the model performs better< / span > < / span >
< span id = "L359" > < span class = "lineNum" > 359< / span > < span class = "tlaGNC" > 180 : numItemsPack = 0;< / span > < / span >
< span id = "L360" > < span class = "lineNum" > 360< / span > < span class = "tlaGNC" > 180 : }< / span > < / span >
< span id = "L361" > < span class = "lineNum" > 361< / span > < span class = "tlaGNC" > 444 : if (convergence_best) {< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L362" > < span class = "lineNum" > 362< / span > : // Keep the best accuracy until now as the prior accuracy< / span >
2024-04-30 12:00:24 +00:00
< span id = "L363" > < span class = "lineNum" > 363< / span > < span class = "tlaGNC" > 48 : priorAccuracy = std::max(accuracy, priorAccuracy);< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L364" > < span class = "lineNum" > 364< / span > : } else {< / span >
< span id = "L365" > < span class = "lineNum" > 365< / span > : // Keep the last accuray obtained as the prior accuracy< / span >
2024-04-30 12:00:24 +00:00
< span id = "L366" > < span class = "lineNum" > 366< / span > < span class = "tlaGNC" > 396 : priorAccuracy = accuracy;< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L367" > < span class = "lineNum" > 367< / span > : }< / span >
2024-04-30 12:00:24 +00:00
< span id = "L368" > < span class = "lineNum" > 368< / span > < span class = "tlaGNC" > 444 : }< / span > < / span >
< span id = "L369" > < span class = "lineNum" > 369< / span > < span class = "tlaGNC" > 642 : VLOG_SCOPE_F(1, " tolerance: %d featuresUsed.size: %zu features.size: %zu" , tolerance, featuresUsed.size(), features.size());< / span > < / span >
< span id = "L370" > < span class = "lineNum" > 370< / span > < span class = "tlaGNC" > 642 : finished = finished || tolerance > maxTolerance || featuresUsed.size() == features.size();< / span > < / span >
< span id = "L371" > < span class = "lineNum" > 371< / span > < span class = "tlaGNC" > 642 : }< / span > < / span >
< span id = "L372" > < span class = "lineNum" > 372< / span > < span class = "tlaGNC" > 114 : if (tolerance > maxTolerance) {< / span > < / span >
< span id = "L373" > < span class = "lineNum" > 373< / span > < span class = "tlaGNC" > 12 : if (numItemsPack < n_models) {< / span > < / span >
< span id = "L374" > < span class = "lineNum" > 374< / span > < span class = "tlaGNC" > 12 : notes.push_back(" Convergence threshold reached & " + std::to_string(numItemsPack) + " models eliminated" );< / span > < / span >
< span id = "L375" > < span class = "lineNum" > 375< / span > < span class = "tlaGNC" > 12 : VLOG_SCOPE_F(4, " Convergence threshold reached & %d models eliminated of %d" , numItemsPack, n_models);< / span > < / span >
< span id = "L376" > < span class = "lineNum" > 376< / span > < span class = "tlaGNC" > 156 : for (int i = 0; i < numItemsPack; ++i) {< / span > < / span >
< span id = "L377" > < span class = "lineNum" > 377< / span > < span class = "tlaGNC" > 144 : significanceModels.pop_back();< / span > < / span >
< span id = "L378" > < span class = "lineNum" > 378< / span > < span class = "tlaGNC" > 144 : models.pop_back();< / span > < / span >
< span id = "L379" > < span class = "lineNum" > 379< / span > < span class = "tlaGNC" > 144 : n_models--;< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L380" > < span class = "lineNum" > 380< / span > : }< / span >
2024-04-30 12:00:24 +00:00
< span id = "L381" > < span class = "lineNum" > 381< / span > < span class = "tlaGNC" > 12 : } else {< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L382" > < span class = "lineNum" > 382< / span > < span class = "tlaUNC tlaBgUNC" > 0 : notes.push_back(" Convergence threshold reached & 0 models eliminated" );< / span > < / span >
< span id = "L383" > < span class = "lineNum" > 383< / span > < span class = "tlaUNC" > 0 : VLOG_SCOPE_F(4, " Convergence threshold reached & 0 models eliminated n_models=%d numItemsPack=%d" , n_models, numItemsPack);< / span > < / span >
< span id = "L384" > < span class = "lineNum" > 384< / span > < span class = "tlaUNC" > 0 : }< / span > < / span >
< span id = "L385" > < span class = "lineNum" > 385< / span > : }< / span >
2024-04-30 12:00:24 +00:00
< span id = "L386" > < span class = "lineNum" > 386< / span > < span class = "tlaGNC tlaBgGNC" > 114 : if (featuresUsed.size() != features.size()) {< / span > < / span >
< span id = "L387" > < span class = "lineNum" > 387< / span > < span class = "tlaGNC" > 6 : notes.push_back(" Used features in train: " + std::to_string(featuresUsed.size()) + " of " + std::to_string(features.size()));< / span > < / span >
< span id = "L388" > < span class = "lineNum" > 388< / span > < span class = "tlaGNC" > 6 : status = WARNING;< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L389" > < span class = "lineNum" > 389< / span > : }< / span >
2024-04-30 12:00:24 +00:00
< span id = "L390" > < span class = "lineNum" > 390< / span > < span class = "tlaGNC" > 114 : notes.push_back(" Number of models: " + std::to_string(n_models));< / span > < / span >
< span id = "L391" > < span class = "lineNum" > 391< / span > < span class = "tlaGNC" > 162 : }< / span > < / span >
< span id = "L392" > < span class = "lineNum" > 392< / span > < span class = "tlaGNC" > 6 : std::vector< std::string> BoostAODE::graph(const std::string& title) const< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L393" > < span class = "lineNum" > 393< / span > : {< / span >
2024-04-30 12:00:24 +00:00
< span id = "L394" > < span class = "lineNum" > 394< / span > < span class = "tlaGNC" > 6 : return Ensemble::graph(title);< / span > < / span >
2024-04-29 22:52:09 +00:00
< span id = "L395" > < span class = "lineNum" > 395< / span > : }< / span >
< span id = "L396" > < span class = "lineNum" > 396< / span > : }< / span >
2024-04-21 14:44:35 +00:00
< / pre >
< / td >
< / tr >
< / table >
< br >
< table width = "100%" border = 0 cellspacing = 0 cellpadding = 0 >
< tr > < td class = "ruler" > < img src = "../../glass.png" width = 3 height = 3 alt = "" > < / td > < / tr >
< tr > < td class = "versionInfo" > Generated by: < a href = "https://github.com//linux-test-project/lcov" target = "_parent" > LCOV version 2.0-1< / a > < / td > < / tr >
< / table >
< br >
< / body >
< / html >