// *************************************************************** // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez // SPDX-FileType: SOURCE // SPDX-License-Identifier: MIT // *************************************************************** #ifndef BAYESNET_METRICS_H #define BAYESNET_METRICS_H #include #include #include namespace bayesnet { class Metrics { public: Metrics() = default; Metrics(const torch::Tensor& samples, const std::vector& features, const std::string& className, const int classNumStates); Metrics(const std::vector>& vsamples, const std::vector& labels, const std::vector& features, const std::string& className, const int classNumStates); std::vector SelectKBestWeighted(const torch::Tensor& weights, bool ascending = false, unsigned k = 0); std::vector getScoresKBest() const; double mutualInformation(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& weights); torch::Tensor conditionalEdge(const torch::Tensor& weights); std::vector> maximumSpanningTree(const std::vector& features, const torch::Tensor& weights, const int root); protected: torch::Tensor samples; // n+1xm torch::Tensor used to fit the model where samples[-1] is the y std::vector std::string className; double entropy(const torch::Tensor& feature, const torch::Tensor& weights); std::vector features; template std::vector> doCombinations(const std::vector& source) { std::vector> result; for (int i = 0; i < source.size(); ++i) { T temp = source[i]; for (int j = i + 1; j < source.size(); ++j) { result.push_back({ temp, source[j] }); } } return result; } template T pop_first(std::vector& v) { T temp = v[0]; v.erase(v.begin()); return temp; } private: int classNumStates = 0; std::vector scoresKBest; std::vector featuresKBest; // sorted indices of the features double conditionalEntropy(const torch::Tensor& firstFeature, const torch::Tensor& secondFeature, const torch::Tensor& weights); }; } #endif