LCOV - code coverage report
Current view: top level - bayesnet/feature_selection - FCBF.cc (source / functions) Coverage Total Hit
Test: coverage.info Lines: 92.3 % 26 24
Test Date: 2024-04-30 13:59:18 Functions: 100.0 % 2 2

            Line data    Source code
       1              : // ***************************************************************
       2              : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
       3              : // SPDX-FileType: SOURCE
       4              : // SPDX-License-Identifier: MIT
       5              : // ***************************************************************
       6              : 
       7              : #include "bayesnet/utils/bayesnetUtils.h"
       8              : #include "FCBF.h"
       9              : namespace bayesnet {
      10              : 
      11           48 :     FCBF::FCBF(const torch::Tensor& samples, const std::vector<std::string>& features, const std::string& className, const int maxFeatures, const int classNumStates, const torch::Tensor& weights, const double threshold) :
      12           48 :         FeatureSelect(samples, features, className, maxFeatures, classNumStates, weights), threshold(threshold)
      13              :     {
      14           48 :         if (threshold < 1e-7) {
      15           14 :             throw std::invalid_argument("Threshold cannot be less than 1e-7");
      16              :         }
      17           48 :     }
      18           34 :     void FCBF::fit()
      19              :     {
      20           34 :         initialize();
      21           34 :         computeSuLabels();
      22           34 :         auto featureOrder = argsort(suLabels); // sort descending order
      23           34 :         auto featureOrderCopy = featureOrder;
      24          284 :         for (const auto& feature : featureOrder) {
      25              :             // Don't self compare
      26          250 :             featureOrderCopy.erase(featureOrderCopy.begin());
      27          250 :             if (suLabels.at(feature) == 0.0) {
      28              :                 // The feature has been removed from the list
      29          108 :                 continue;
      30              :             }
      31          142 :             if (suLabels.at(feature) < threshold) {
      32            0 :                 break;
      33              :             }
      34              :             // Remove redundant features
      35          781 :             for (const auto& featureCopy : featureOrderCopy) {
      36          639 :                 double value = computeSuFeatures(feature, featureCopy);
      37          639 :                 if (value >= suLabels.at(featureCopy)) {
      38              :                     // Remove feature from list
      39          221 :                     suLabels[featureCopy] = 0.0;
      40              :                 }
      41              :             }
      42          142 :             selectedFeatures.push_back(feature);
      43          142 :             selectedScores.push_back(suLabels[feature]);
      44          142 :             if (selectedFeatures.size() == maxFeatures) {
      45            0 :                 break;
      46              :             }
      47              :         }
      48           34 :         fitted = true;
      49           34 :     }
      50              : }
        

Generated by: LCOV version 2.0-1