Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
7 : #include "bayesnet/utils/bayesnetUtils.h"
8 : #include "FCBF.h"
9 : namespace bayesnet {
10 :
11 48 : FCBF::FCBF(const torch::Tensor& samples, const std::vector<std::string>& features, const std::string& className, const int maxFeatures, const int classNumStates, const torch::Tensor& weights, const double threshold) :
12 48 : FeatureSelect(samples, features, className, maxFeatures, classNumStates, weights), threshold(threshold)
13 : {
14 48 : if (threshold < 1e-7) {
15 14 : throw std::invalid_argument("Threshold cannot be less than 1e-7");
16 : }
17 48 : }
18 34 : void FCBF::fit()
19 : {
20 34 : initialize();
21 34 : computeSuLabels();
22 34 : auto featureOrder = argsort(suLabels); // sort descending order
23 34 : auto featureOrderCopy = featureOrder;
24 284 : for (const auto& feature : featureOrder) {
25 : // Don't self compare
26 250 : featureOrderCopy.erase(featureOrderCopy.begin());
27 250 : if (suLabels.at(feature) == 0.0) {
28 : // The feature has been removed from the list
29 108 : continue;
30 : }
31 142 : if (suLabels.at(feature) < threshold) {
32 0 : break;
33 : }
34 : // Remove redundant features
35 781 : for (const auto& featureCopy : featureOrderCopy) {
36 639 : double value = computeSuFeatures(feature, featureCopy);
37 639 : if (value >= suLabels.at(featureCopy)) {
38 : // Remove feature from list
39 221 : suLabels[featureCopy] = 0.0;
40 : }
41 : }
42 142 : selectedFeatures.push_back(feature);
43 142 : selectedScores.push_back(suLabels[feature]);
44 142 : if (selectedFeatures.size() == maxFeatures) {
45 0 : break;
46 : }
47 : }
48 34 : fitted = true;
49 34 : }
50 : }
|