Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
7 : #ifndef BAYESNET_METRICS_H
8 : #define BAYESNET_METRICS_H
9 : #include <vector>
10 : #include <string>
11 : #include <torch/torch.h>
12 : namespace bayesnet {
13 : class Metrics {
14 : public:
15 4750 : Metrics() = default;
16 : Metrics(const torch::Tensor& samples, const std::vector<std::string>& features, const std::string& className, const int classNumStates);
17 : Metrics(const std::vector<std::vector<int>>& vsamples, const std::vector<int>& labels, const std::vector<std::string>& features, const std::string& className, const int classNumStates);
18 : std::vector<int> SelectKBestWeighted(const torch::Tensor& weights, bool ascending = false, unsigned k = 0);
19 : std::vector<double> getScoresKBest() const;
20 : double mutualInformation(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& weights);
21 : std::vector<float> conditionalEdgeWeights(std::vector<float>& weights); // To use in Python
22 : torch::Tensor conditionalEdge(const torch::Tensor& weights);
23 : std::vector<std::pair<int, int>> maximumSpanningTree(const std::vector<std::string>& features, const torch::Tensor& weights, const int root);
24 : protected:
25 : torch::Tensor samples; // n+1xm torch::Tensor used to fit the model where samples[-1] is the y std::vector
26 : std::string className;
27 : double entropy(const torch::Tensor& feature, const torch::Tensor& weights);
28 : std::vector<std::string> features;
29 : template <class T>
30 2225 : std::vector<std::pair<T, T>> doCombinations(const std::vector<T>& source)
31 : {
32 2225 : std::vector<std::pair<T, T>> result;
33 11660 : for (int i = 0; i < source.size(); ++i) {
34 9435 : T temp = source[i];
35 29517 : for (int j = i + 1; j < source.size(); ++j) {
36 20082 : result.push_back({ temp, source[j] });
37 : }
38 : }
39 2225 : return result;
40 0 : }
41 : template <class T>
42 116 : T pop_first(std::vector<T>& v)
43 : {
44 116 : T temp = v[0];
45 116 : v.erase(v.begin());
46 116 : return temp;
47 : }
48 : private:
49 : int classNumStates = 0;
50 : std::vector<double> scoresKBest;
51 : std::vector<int> featuresKBest; // sorted indices of the features
52 : double conditionalEntropy(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& weights);
53 : };
54 : }
55 : #endif
|