Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
#include <cmath>
#include <limits>
#include <stdexcept>
#include "bayesnet/utils/bayesnetUtils.h"
#include "IWSS.h"
10 : namespace bayesnet {
11 102 : IWSS::IWSS(const torch::Tensor& samples, const std::vector<std::string>& features, const std::string& className, const int maxFeatures, const int classNumStates, const torch::Tensor& weights, const double threshold) :
12 102 : FeatureSelect(samples, features, className, maxFeatures, classNumStates, weights), threshold(threshold)
13 : {
14 102 : if (threshold < 0 || threshold > .5) {
15 44 : throw std::invalid_argument("Threshold has to be in [0, 0.5]");
16 : }
17 102 : }
18 58 : void IWSS::fit()
19 : {
20 58 : initialize();
21 58 : computeSuLabels();
22 58 : auto featureOrder = argsort(suLabels); // sort descending order
23 58 : auto featureOrderCopy = featureOrder;
24 : // Add first and second features to result
25 : // First with its own score
26 58 : auto first_feature = pop_first(featureOrderCopy);
27 58 : selectedFeatures.push_back(first_feature);
28 58 : selectedScores.push_back(suLabels.at(first_feature));
29 : // Second with the score of the candidates
30 58 : selectedFeatures.push_back(pop_first(featureOrderCopy));
31 58 : auto merit = computeMeritCFS();
32 58 : selectedScores.push_back(merit);
33 192 : for (const auto feature : featureOrderCopy) {
34 192 : selectedFeatures.push_back(feature);
35 : // Compute merit with selectedFeatures
36 192 : auto meritNew = computeMeritCFS();
37 192 : double delta = merit != 0.0 ? std::abs(merit - meritNew) / merit : 0.0;
38 192 : if (meritNew > merit || delta < threshold) {
39 134 : if (meritNew > merit) {
40 0 : merit = meritNew;
41 : }
42 134 : selectedScores.push_back(meritNew);
43 : } else {
44 58 : selectedFeatures.pop_back();
45 58 : break;
46 : }
47 134 : if (selectedFeatures.size() == maxFeatures) {
48 0 : break;
49 : }
50 : }
51 58 : fitted = true;
52 58 : }
53 : }
|