Files
BayesNet/html/bayesnet/classifiers/TAN.cc.gcov.html

8.9 KiB

<html lang="en"> <head> </head>
LCOV - code coverage report
Current view: top level - bayesnet/classifiers - TAN.cc (source / functions) Coverage Total Hit
Test: coverage.info Lines: 100.0 % 23 23
Test Date: 2024-04-30 20:26:57 Functions: 100.0 % 4 4

            Line data    Source code
       1              : // ***************************************************************
       2              : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
       3              : // SPDX-FileType: SOURCE
       4              : // SPDX-License-Identifier: MIT
       5              : // ***************************************************************
       6              : 
       7              : #include "TAN.h"
       8              : 
       9              : namespace bayesnet {
      10           94 :     TAN::TAN() : Classifier(Network()) {}
      11              : 
      12           26 :     void TAN::buildModel(const torch::Tensor& weights)
      13              :     {
      14              :         // 0. Add all nodes to the model
      15           26 :         addNodes();
      16              :         // 1. Compute mutual information between each feature and the class and set the root node
      17              :         // as the highest mutual information with the class
      18           26 :         auto mi = std::vector <std::pair<int, float >>();
      19           78 :         torch::Tensor class_dataset = dataset.index({ -1, "..." });
      20          178 :         for (int i = 0; i < static_cast<int>(features.size()); ++i) {
      21          456 :             torch::Tensor feature_dataset = dataset.index({ i, "..." });
      22          152 :             auto mi_value = metrics.mutualInformation(class_dataset, feature_dataset, weights);
      23          152 :             mi.push_back({ i, mi_value });
      24          152 :         }
      25          350 :         sort(mi.begin(), mi.end(), [](const auto& left, const auto& right) {return left.second < right.second;});
      26           26 :         auto root = mi[mi.size() - 1].first;
      27              :         // 2. Compute mutual information between each feature and the class
      28           26 :         auto weights_matrix = metrics.conditionalEdge(weights);
      29              :         // 3. Compute the maximum spanning tree
      30           26 :         auto mst = metrics.maximumSpanningTree(features, weights_matrix, root);
      31              :         // 4. Add edges from the maximum spanning tree to the model
      32          152 :         for (auto i = 0; i < mst.size(); ++i) {
      33          126 :             auto [from, to] = mst[i];
      34          126 :             model.addEdge(features[from], features[to]);
      35              :         }
      36              :         // 5. Add edges from the class to all features
      37          178 :         for (auto feature : features) {
      38          152 :             model.addEdge(className, feature);
      39          152 :         }
      40          204 :     }
      41            4 :     std::vector<std::string> TAN::graph(const std::string& title) const
      42              :     {
      43            4 :         return model.graph(title);
      44              :     }
      45              : }
        

Generated by: LCOV version 2.0-1

</html>