9.4 KiB
9.4 KiB
<html lang="en">
<head>
</head>
</html>
LCOV - code coverage report | ||||||||||||||||||||||
![]() | ||||||||||||||||||||||
|
||||||||||||||||||||||
![]() |
Line data Source code 1 : // *************************************************************** 2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez 3 : // SPDX-FileType: SOURCE 4 : // SPDX-License-Identifier: MIT 5 : // *************************************************************** 6 : 7 : #include "bayesnet/utils/bayesnetUtils.h" 8 : #include "FCBF.h" 9 : namespace bayesnet { 10 : 11 78 : FCBF::FCBF(const torch::Tensor& samples, const std::vector<std::string>& features, const std::string& className, const int maxFeatures, const int classNumStates, const torch::Tensor& weights, const double threshold) : 12 78 : FeatureSelect(samples, features, className, maxFeatures, classNumStates, weights), threshold(threshold) 13 : { 14 78 : if (threshold < 1e-7) { 15 22 : throw std::invalid_argument("Threshold cannot be less than 1e-7"); 16 : } 17 78 : } 18 56 : void FCBF::fit() 19 : { 20 56 : initialize(); 21 56 : computeSuLabels(); 22 56 : auto featureOrder = argsort(suLabels); // sort descending order 23 56 : auto featureOrderCopy = featureOrder; 24 472 : for (const auto& feature : featureOrder) { 25 : // Don't self compare 26 416 : featureOrderCopy.erase(featureOrderCopy.begin()); 27 416 : if (suLabels.at(feature) == 0.0) { 28 : // The feature has been removed from the list 29 180 : continue; 30 : } 31 236 : if (suLabels.at(feature) < threshold) { 32 0 : break; 33 : } 34 : // Remove redundant features 35 1307 : for (const auto& featureCopy : featureOrderCopy) { 36 1071 : double value = computeSuFeatures(feature, featureCopy); 37 1071 : if (value >= suLabels.at(featureCopy)) { 38 : // Remove feature from list 39 373 : suLabels[featureCopy] = 0.0; 40 : } 41 : } 42 236 : selectedFeatures.push_back(feature); 43 236 : selectedScores.push_back(suLabels[feature]); 44 236 : if (selectedFeatures.size() == maxFeatures) { 45 0 : break; 46 : } 47 : } 48 56 : fitted = true; 49 56 : } 50 : } |
![]() |
Generated by: LCOV version 2.0-1 |
</html>