Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
7 : #ifndef CFS_H
8 : #define CFS_H
9 : #include <torch/torch.h>
10 : #include <vector>
11 : #include "bayesnet/feature_selection/FeatureSelect.h"
12 : namespace bayesnet {
13 : class CFS : public FeatureSelect {
14 : public:
15 : // dataset is a n+1xm tensor of integers where dataset[-1] is the y std::vector
16 26 : CFS(const torch::Tensor& samples, const std::vector<std::string>& features, const std::string& className, const int maxFeatures, const int classNumStates, const torch::Tensor& weights) :
17 26 : FeatureSelect(samples, features, className, maxFeatures, classNumStates, weights)
18 : {
19 26 : }
20 88 : virtual ~CFS() {};
21 : void fit() override;
22 : private:
23 : bool computeContinueCondition(const std::vector<int>& featureOrder);
24 : };
25 : }
26 : #endif
|