Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
7 : #include <limits>
8 : #include "bayesnet/utils/bayesnetUtils.h"
9 : #include "IWSS.h"
10 : namespace bayesnet {
11 62 : IWSS::IWSS(const torch::Tensor& samples, const std::vector<std::string>& features, const std::string& className, const int maxFeatures, const int classNumStates, const torch::Tensor& weights, const double threshold) :
12 62 : FeatureSelect(samples, features, className, maxFeatures, classNumStates, weights), threshold(threshold)
13 : {
14 62 : if (threshold < 0 || threshold > .5) {
15 28 : throw std::invalid_argument("Threshold has to be in [0, 0.5]");
16 : }
17 62 : }
18 34 : void IWSS::fit()
19 : {
20 34 : initialize();
21 34 : computeSuLabels();
22 34 : auto featureOrder = argsort(suLabels); // sort descending order
23 34 : auto featureOrderCopy = featureOrder;
24 : // Add first and second features to result
25 : // First with its own score
26 34 : auto first_feature = pop_first(featureOrderCopy);
27 34 : selectedFeatures.push_back(first_feature);
28 34 : selectedScores.push_back(suLabels.at(first_feature));
29 : // Second with the score of the candidates
30 34 : selectedFeatures.push_back(pop_first(featureOrderCopy));
31 34 : auto merit = computeMeritCFS();
32 34 : selectedScores.push_back(merit);
33 116 : for (const auto feature : featureOrderCopy) {
34 116 : selectedFeatures.push_back(feature);
35 : // Compute merit with selectedFeatures
36 116 : auto meritNew = computeMeritCFS();
37 116 : double delta = merit != 0.0 ? std::abs(merit - meritNew) / merit : 0.0;
38 116 : if (meritNew > merit || delta < threshold) {
39 82 : if (meritNew > merit) {
40 0 : merit = meritNew;
41 : }
42 82 : selectedScores.push_back(meritNew);
43 : } else {
44 34 : selectedFeatures.pop_back();
45 34 : break;
46 : }
47 82 : if (selectedFeatures.size() == maxFeatures) {
48 0 : break;
49 : }
50 : }
51 34 : fitted = true;
52 34 : }
53 : }
|