Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
#include <algorithm>
#include <cstddef>
#include <functional>
#include <limits>
#include "bayesnet/utils/bayesnetUtils.h"
#include "CFS.h"
10 : namespace bayesnet {
11 48 : void CFS::fit()
12 : {
13 48 : initialize();
14 48 : computeSuLabels();
15 48 : auto featureOrder = argsort(suLabels); // sort descending order
16 48 : auto continueCondition = true;
17 48 : auto feature = featureOrder[0];
18 48 : selectedFeatures.push_back(feature);
19 48 : selectedScores.push_back(suLabels[feature]);
20 48 : featureOrder.erase(featureOrder.begin());
21 268 : while (continueCondition) {
22 220 : double merit = std::numeric_limits<double>::lowest();
23 220 : int bestFeature = -1;
24 1250 : for (auto feature : featureOrder) {
25 1030 : selectedFeatures.push_back(feature);
26 : // Compute merit with selectedFeatures
27 1030 : auto meritNew = computeMeritCFS();
28 1030 : if (meritNew > merit) {
29 450 : merit = meritNew;
30 450 : bestFeature = feature;
31 : }
32 1030 : selectedFeatures.pop_back();
33 : }
34 220 : if (bestFeature == -1) {
35 : // meritNew has to be nan due to constant features
36 0 : break;
37 : }
38 220 : selectedFeatures.push_back(bestFeature);
39 220 : selectedScores.push_back(merit);
40 220 : featureOrder.erase(remove(featureOrder.begin(), featureOrder.end(), bestFeature), featureOrder.end());
41 220 : continueCondition = computeContinueCondition(featureOrder);
42 : }
43 48 : fitted = true;
44 48 : }
45 220 : bool CFS::computeContinueCondition(const std::vector<int>& featureOrder)
46 : {
47 220 : if (selectedFeatures.size() == maxFeatures || featureOrder.size() == 0) {
48 10 : return false;
49 : }
50 210 : if (selectedScores.size() >= 5) {
51 : /*
52 : "To prevent the best first search from exploring the entire
53 : feature subset search space, a stopping criterion is imposed.
54 : The search will terminate if five consecutive fully expanded
55 : subsets show no improvement over the current best subset."
56 : as stated in Mark A.Hall Thesis
57 : */
58 76 : double item_ant = std::numeric_limits<double>::lowest();
59 76 : int num = 0;
60 76 : std::vector<double> lastFive(selectedScores.end() - 5, selectedScores.end());
61 304 : for (auto item : lastFive) {
62 266 : if (item_ant == std::numeric_limits<double>::lowest()) {
63 76 : item_ant = item;
64 : }
65 266 : if (item > item_ant) {
66 38 : break;
67 : } else {
68 228 : num++;
69 228 : item_ant = item;
70 : }
71 : }
72 76 : if (num == 5) {
73 38 : return false;
74 : }
75 76 : }
76 172 : return true;
77 : }
78 : }
|