LCOV - code coverage report
Current view: top level - bayesnet/classifiers - TAN.cc (source / functions) Coverage Total Hit
Test: BayesNet Coverage Report Lines: 100.0 % 23 23
Test Date: 2024-05-06 17:54:04 Functions: 100.0 % 4 4
Legend: Lines: hit not hit

            Line data    Source code
       1              : // ***************************************************************
       2              : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
       3              : // SPDX-FileType: SOURCE
       4              : // SPDX-License-Identifier: MIT
       5              : // ***************************************************************
       6              : 
       7              : #include "TAN.h"
       8              : 
       9              : namespace bayesnet {
      10          188 :     TAN::TAN() : Classifier(Network()) {}
      11              : 
      12           52 :     void TAN::buildModel(const torch::Tensor& weights)
      13              :     {
      14              :         // 0. Add all nodes to the model
      15           52 :         addNodes();
      16              :         // 1. Compute mutual information between each feature and the class and set the root node
      17              :         // as the highest mutual information with the class
      18           52 :         auto mi = std::vector <std::pair<int, float >>();
      19          156 :         torch::Tensor class_dataset = dataset.index({ -1, "..." });
      20          356 :         for (int i = 0; i < static_cast<int>(features.size()); ++i) {
      21          912 :             torch::Tensor feature_dataset = dataset.index({ i, "..." });
      22          304 :             auto mi_value = metrics.mutualInformation(class_dataset, feature_dataset, weights);
      23          304 :             mi.push_back({ i, mi_value });
      24          304 :         }
      25          700 :         sort(mi.begin(), mi.end(), [](const auto& left, const auto& right) {return left.second < right.second;});
      26           52 :         auto root = mi[mi.size() - 1].first;
      27              :         // 2. Compute mutual information between each feature and the class
      28           52 :         auto weights_matrix = metrics.conditionalEdge(weights);
      29              :         // 3. Compute the maximum spanning tree
      30           52 :         auto mst = metrics.maximumSpanningTree(features, weights_matrix, root);
      31              :         // 4. Add edges from the maximum spanning tree to the model
      32          304 :         for (auto i = 0; i < mst.size(); ++i) {
      33          252 :             auto [from, to] = mst[i];
      34          252 :             model.addEdge(features[from], features[to]);
      35              :         }
      36              :         // 5. Add edges from the class to all features
      37          356 :         for (auto feature : features) {
      38          304 :             model.addEdge(className, feature);
      39          304 :         }
      40          408 :     }
      41            8 :     std::vector<std::string> TAN::graph(const std::string& title) const
      42              :     {
      43            8 :         return model.graph(title);
      44              :     }
      45              : }
        

Generated by: LCOV version 2.0-1