Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
#include <algorithm>
#include <limits>
#include <vector>
#include "bayesnet/utils/bayesnetUtils.h"
#include "CFS.h"
10 : namespace bayesnet {
11 40 : void CFS::fit()
12 : {
13 40 : initialize();
14 40 : computeSuLabels();
15 40 : auto featureOrder = argsort(suLabels); // sort descending order
16 40 : auto continueCondition = true;
17 40 : auto feature = featureOrder[0];
18 40 : selectedFeatures.push_back(feature);
19 40 : selectedScores.push_back(suLabels[feature]);
20 40 : featureOrder.erase(featureOrder.begin());
21 226 : while (continueCondition) {
22 186 : double merit = std::numeric_limits<double>::lowest();
23 186 : int bestFeature = -1;
24 1083 : for (auto feature : featureOrder) {
25 897 : selectedFeatures.push_back(feature);
26 : // Compute merit with selectedFeatures
27 897 : auto meritNew = computeMeritCFS();
28 897 : if (meritNew > merit) {
29 379 : merit = meritNew;
30 379 : bestFeature = feature;
31 : }
32 897 : selectedFeatures.pop_back();
33 : }
34 186 : if (bestFeature == -1) {
35 : // meritNew has to be nan due to constant features
36 0 : break;
37 : }
38 186 : selectedFeatures.push_back(bestFeature);
39 186 : selectedScores.push_back(merit);
40 186 : featureOrder.erase(remove(featureOrder.begin(), featureOrder.end(), bestFeature), featureOrder.end());
41 186 : continueCondition = computeContinueCondition(featureOrder);
42 : }
43 40 : fitted = true;
44 40 : }
45 186 : bool CFS::computeContinueCondition(const std::vector<int>& featureOrder)
46 : {
47 186 : if (selectedFeatures.size() == maxFeatures || featureOrder.size() == 0) {
48 7 : return false;
49 : }
50 179 : if (selectedScores.size() >= 5) {
51 : /*
52 : "To prevent the best first search from exploring the entire
53 : feature subset search space, a stopping criterion is imposed.
54 : The search will terminate if five consecutive fully expanded
55 : subsets show no improvement over the current best subset."
56 : as stated in Mark A.Hall Thesis
57 : */
58 66 : double item_ant = std::numeric_limits<double>::lowest();
59 66 : int num = 0;
60 66 : std::vector<double> lastFive(selectedScores.end() - 5, selectedScores.end());
61 264 : for (auto item : lastFive) {
62 231 : if (item_ant == std::numeric_limits<double>::lowest()) {
63 66 : item_ant = item;
64 : }
65 231 : if (item > item_ant) {
66 33 : break;
67 : } else {
68 198 : num++;
69 198 : item_ant = item;
70 : }
71 : }
72 66 : if (num == 5) {
73 33 : return false;
74 : }
75 66 : }
76 146 : return true;
77 : }
78 : }
|