Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
7 : #include "bayesnet/utils/bayesnetUtils.h"
8 : #include "FCBF.h"
9 : namespace bayesnet {
10 :
11 14 : FCBF::FCBF(const torch::Tensor& samples, const std::vector<std::string>& features, const std::string& className, const int maxFeatures, const int classNumStates, const torch::Tensor& weights, const double threshold) :
12 14 : FeatureSelect(samples, features, className, maxFeatures, classNumStates, weights), threshold(threshold)
13 : {
14 14 : if (threshold < 1e-7) {
15 4 : throw std::invalid_argument("Threshold cannot be less than 1e-7");
16 : }
17 14 : }
18 10 : void FCBF::fit()
19 : {
20 10 : initialize();
21 10 : computeSuLabels();
22 10 : auto featureOrder = argsort(suLabels); // sort descending order
23 10 : auto featureOrderCopy = featureOrder;
24 84 : for (const auto& feature : featureOrder) {
25 : // Don't self compare
26 74 : featureOrderCopy.erase(featureOrderCopy.begin());
27 74 : if (suLabels.at(feature) == 0.0) {
28 : // The feature has been removed from the list
29 32 : continue;
30 : }
31 42 : if (suLabels.at(feature) < threshold) {
32 0 : break;
33 : }
34 : // Remove redundant features
35 232 : for (const auto& featureCopy : featureOrderCopy) {
36 190 : double value = computeSuFeatures(feature, featureCopy);
37 190 : if (value >= suLabels.at(featureCopy)) {
38 : // Remove feature from list
39 66 : suLabels[featureCopy] = 0.0;
40 : }
41 : }
42 42 : selectedFeatures.push_back(feature);
43 42 : selectedScores.push_back(suLabels[feature]);
44 42 : if (selectedFeatures.size() == maxFeatures) {
45 0 : break;
46 : }
47 : }
48 10 : fitted = true;
49 10 : }
50 : }
|