Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
7 : #ifndef FEATURE_SELECT_H
8 : #define FEATURE_SELECT_H
#include <map>
#include <string>
#include <utility>
#include <vector>
#include <torch/torch.h>
#include "bayesnet/utils/BayesMetrics.h"
12 : namespace bayesnet {
13 : class FeatureSelect : public Metrics {
14 : public:
15 : // dataset is a n+1xm tensor of integers where dataset[-1] is the y std::vector
16 : FeatureSelect(const torch::Tensor& samples, const std::vector<std::string>& features, const std::string& className, const int maxFeatures, const int classNumStates, const torch::Tensor& weights);
17 88 : virtual ~FeatureSelect() {};
18 : virtual void fit() = 0;
19 : std::vector<int> getFeatures() const;
20 : std::vector<double> getScores() const;
21 : protected:
22 : void initialize();
23 : void computeSuLabels();
24 : double computeSuFeatures(const int a, const int b);
25 : double symmetricalUncertainty(int a, int b);
26 : double computeMeritCFS();
27 : const torch::Tensor& weights;
28 : int maxFeatures;
29 : std::vector<int> selectedFeatures;
30 : std::vector<double> selectedScores;
31 : std::vector<double> suLabels;
32 : std::map<std::pair<int, int>, double> suFeatures;
33 : bool fitted = false;
34 : };
35 : }
36 : #endif
|