Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
7 : #ifndef BAYESNET_METRICS_H
8 : #define BAYESNET_METRICS_H
9 : #include <vector>
10 : #include <string>
11 : #include <torch/torch.h>
12 : namespace bayesnet {
13 : class Metrics {
14 : public:
15 2240 : Metrics() = default;
16 : Metrics(const torch::Tensor& samples, const std::vector<std::string>& features, const std::string& className, const int classNumStates);
17 : Metrics(const std::vector<std::vector<int>>& vsamples, const std::vector<int>& labels, const std::vector<std::string>& features, const std::string& className, const int classNumStates);
18 : std::vector<int> SelectKBestWeighted(const torch::Tensor& weights, bool ascending = false, unsigned k = 0);
19 : std::vector<double> getScoresKBest() const;
20 : double mutualInformation(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& weights);
21 : torch::Tensor conditionalEdge(const torch::Tensor& weights);
22 : std::vector<std::pair<int, int>> maximumSpanningTree(const std::vector<std::string>& features, const torch::Tensor& weights, const int root);
23 : protected:
24 : torch::Tensor samples; // (n+1) x m torch::Tensor used to fit the model, where samples[-1] is the y vector
25 : std::string className;
26 : double entropy(const torch::Tensor& feature, const torch::Tensor& weights);
27 : std::vector<std::string> features;
/**
 * @brief Build every unordered pair of elements taken from @p source.
 *
 * Pairs are emitted in index order: element i is paired with each element
 * j > i, so n inputs yield n * (n - 1) / 2 pairs. The result is empty when
 * @p source holds fewer than two elements.
 *
 * @tparam T element type; each pair stores copies of the two elements.
 * @param source values to combine.
 * @return all pairs (source[i], source[j]) with i < j.
 */
template <class T>
std::vector<std::pair<T, T>> doCombinations(const std::vector<T>& source)
{
    std::vector<std::pair<T, T>> result;
    const auto count = source.size();
    if (count < 2) {
        return result;
    }
    // The final size is known, so allocate once instead of growing repeatedly.
    result.reserve(count * (count - 1) / 2);
    // Iterators avoid comparing a signed int index against the unsigned size().
    for (auto first = source.begin(); first != source.end(); ++first) {
        for (auto second = first + 1; second != source.end(); ++second) {
            result.emplace_back(*first, *second);
        }
    }
    return result;
}
/**
 * @brief Remove the first element of @p v and return it.
 *
 * The remaining elements shift down by one position, so the cost is linear
 * in the size of the vector.
 *
 * @tparam T element type; the returned value is a copy of the front element.
 * @param v vector to pop from; modified in place.
 * @return the element that was at index 0.
 * @throws std::out_of_range if @p v is empty.
 */
template <class T>
T pop_first(std::vector<T>& v)
{
    // at() is bounds-checked: an empty vector throws instead of invoking
    // undefined behaviour as operator[] would.
    T first = v.at(0);
    v.erase(v.begin());
    return first;
}
47 : private:
48 : int classNumStates = 0;
49 : std::vector<double> scoresKBest;
50 : std::vector<int> featuresKBest; // sorted indices of the features
51 : double conditionalEntropy(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& weights);
52 : };
53 : }
54 : #endif
|