LCOV - code coverage report
Current view: top level - bayesnet/utils - BayesMetrics.h (source / functions) Coverage Total Hit
Test: BayesNet Coverage Report Lines: 100.0 % 13 13
Test Date: 2024-05-06 17:54:04 Functions: 100.0 % 4 4
Legend: Lines: hit not hit

            Line data    Source code
       1              : // ***************************************************************
       2              : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
       3              : // SPDX-FileType: SOURCE
       4              : // SPDX-License-Identifier: MIT
       5              : // ***************************************************************
       6              : 
       7              : #ifndef BAYESNET_METRICS_H
       8              : #define BAYESNET_METRICS_H
       9              : #include <vector>
      10              : #include <string>
      11              : #include <torch/torch.h>
      12              : namespace bayesnet {
      13              :     class Metrics {
      14              :     public:
      15         2240 :         Metrics() = default;
      16              :         Metrics(const torch::Tensor& samples, const std::vector<std::string>& features, const std::string& className, const int classNumStates);
      17              :         Metrics(const std::vector<std::vector<int>>& vsamples, const std::vector<int>& labels, const std::vector<std::string>& features, const std::string& className, const int classNumStates);
      18              :         std::vector<int> SelectKBestWeighted(const torch::Tensor& weights, bool ascending = false, unsigned k = 0);
      19              :         std::vector<double> getScoresKBest() const;
      20              :         double mutualInformation(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& weights);
      21              :         torch::Tensor conditionalEdge(const torch::Tensor& weights);
      22              :         std::vector<std::pair<int, int>> maximumSpanningTree(const std::vector<std::string>& features, const torch::Tensor& weights, const int root);
      23              :     protected:
      24              :         torch::Tensor samples; // n+1xm torch::Tensor used to fit the model where samples[-1] is the y std::vector
      25              :         std::string className;
      26              :         double entropy(const torch::Tensor& feature, const torch::Tensor& weights);
      27              :         std::vector<std::string> features;
      28              :         template <class T>
      29         1391 :         std::vector<std::pair<T, T>> doCombinations(const std::vector<T>& source)
      30              :         {
      31         1391 :             std::vector<std::pair<T, T>> result;
      32         6981 :             for (int i = 0; i < source.size(); ++i) {
      33         5590 :                 T temp = source[i];
      34        16517 :                 for (int j = i + 1; j < source.size(); ++j) {
      35        10927 :                     result.push_back({ temp, source[j] });
      36              :                 }
      37              :             }
      38         2782 :             return result;
      39         1391 :         }
      40              :         template <class T>
      41           94 :         T pop_first(std::vector<T>& v)
      42              :         {
      43           94 :             T temp = v[0];
      44           94 :             v.erase(v.begin());
      45           94 :             return temp;
      46              :         }
      47              :     private:
      48              :         int classNumStates = 0;
      49              :         std::vector<double> scoresKBest;
      50              :         std::vector<int> featuresKBest; // sorted indices of the features
      51              :         double conditionalEntropy(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& weights);
      52              :     };
      53              : }
      54              : #endif
        

Generated by: LCOV version 2.0-1