Files
BayesNet/html/bayesnet/network/Node.cc.gcov.html

22 KiB

<html lang="en"> <head> </head>
LCOV - code coverage report
Current view: top level - bayesnet/network - Node.cc (source / functions) Coverage Total Hit
Test: coverage.info Lines: 95.5 % 88 84
Test Date: 2024-04-30 20:26:57 Functions: 100.0 % 20 20

            Line data    Source code
       1              : // ***************************************************************
       2              : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
       3              : // SPDX-FileType: SOURCE
       4              : // SPDX-License-Identifier: MIT
       5              : // ***************************************************************
       6              : 
       7              : #include "Node.h"
       8              : 
       9              : namespace bayesnet {
      10              : 
      11        13392 :     Node::Node(const std::string& name)
      12        13392 :         : name(name), numStates(0), cpTable(torch::Tensor()), parents(std::vector<Node*>()), children(std::vector<Node*>())
      13              :     {
      14        13392 :     }
      15            2 :     void Node::clear()
      16              :     {
      17            2 :         parents.clear();
      18            2 :         children.clear();
      19            2 :         cpTable = torch::Tensor();
      20            2 :         dimensions.clear();
      21            2 :         numStates = 0;
      22            2 :     }
      23     68402672 :     std::string Node::getName() const
      24              :     {
      25     68402672 :         return name;
      26              :     }
      27        24964 :     void Node::addParent(Node* parent)
      28              :     {
      29        24964 :         parents.push_back(parent);
      30        24964 :     }
      31            6 :     void Node::removeParent(Node* parent)
      32              :     {
      33            6 :         parents.erase(std::remove(parents.begin(), parents.end(), parent), parents.end());
      34            6 :     }
      35            6 :     void Node::removeChild(Node* child)
      36              :     {
      37            6 :         children.erase(std::remove(children.begin(), children.end(), child), children.end());
      38            6 :     }
      39        24968 :     void Node::addChild(Node* child)
      40              :     {
      41        24968 :         children.push_back(child);
      42        24968 :     }
      43         2536 :     std::vector<Node*>& Node::getParents()
      44              :     {
      45         2536 :         return parents;
      46              :     }
      47        33428 :     std::vector<Node*>& Node::getChildren()
      48              :     {
      49        33428 :         return children;
      50              :     }
      51        27184 :     int Node::getNumStates() const
      52              :     {
      53        27184 :         return numStates;
      54              :     }
      55        14194 :     void Node::setNumStates(int numStates)
      56              :     {
      57        14194 :         this->numStates = numStates;
      58        14194 :     }
      59          210 :     torch::Tensor& Node::getCPT()
      60              :     {
      61          210 :         return cpTable;
      62              :     }
      63              :     /*
      64              :      The MinFill criterion is a heuristic for variable elimination.
      65              :      It selects the variable that minimizes the number of edges that need to be added to the graph to make it triangulated.
      66              :      This is done by counting the number of edges that need to be added to the graph if the variable is eliminated.
      67              :      The variable with the minimum number of edges is chosen.
      68              :      Here this is done by computing the number of combinations of the node's neighbors taken 2 by 2.
      69              :     */
      70           10 :     unsigned Node::minFill()
      71              :     {
      72           10 :         std::unordered_set<std::string> neighbors;
      73           26 :         for (auto child : children) {
      74           16 :             neighbors.emplace(child->getName());
      75              :         }
      76           24 :         for (auto parent : parents) {
      77           14 :             neighbors.emplace(parent->getName());
      78              :         }
      79           10 :         auto source = std::vector<std::string>(neighbors.begin(), neighbors.end());
      80           20 :         return combinations(source).size();
      81           10 :     }
      82           10 :     std::vector<std::pair<std::string, std::string>> Node::combinations(const std::vector<std::string>& source)
      83              :     {
      84           10 :         std::vector<std::pair<std::string, std::string>> result;
      85           40 :         for (int i = 0; i < source.size(); ++i) {
      86           30 :             std::string temp = source[i];
      87           62 :             for (int j = i + 1; j < source.size(); ++j) {
      88           32 :                 result.push_back({ temp, source[j] });
      89              :             }
      90           30 :         }
      91           10 :         return result;
      92            0 :     }
      93        14194 :     void Node::computeCPT(const torch::Tensor& dataset, const std::vector<std::string>& features, const double laplaceSmoothing, const torch::Tensor& weights)
      94              :     {
      95        14194 :         dimensions.clear();
      96              :         // Get dimensions of the CPT
      97        14194 :         dimensions.push_back(numStates);
      98        40554 :         transform(parents.begin(), parents.end(), back_inserter(dimensions), [](const auto& parent) { return parent->getNumStates(); });
      99              : 
     100              :         // Create a tensor of zeros with the dimensions of the CPT
     101        14194 :         cpTable = torch::zeros(dimensions, torch::kFloat) + laplaceSmoothing;
     102              :         // Fill table with counts
     103        14194 :         auto pos = find(features.begin(), features.end(), name);
     104        14194 :         if (pos == features.end()) {
     105            0 :             throw std::logic_error("Feature " + name + " not found in dataset");
     106              :         }
     107        14194 :         int name_index = pos - features.begin();
     108      5137182 :         for (int n_sample = 0; n_sample < dataset.size(1); ++n_sample) {
     109      5122988 :             c10::List<c10::optional<at::Tensor>> coordinates;
     110     15368964 :             coordinates.push_back(dataset.index({ name_index, n_sample }));
     111     14647408 :             for (auto parent : parents) {
     112      9524420 :                 pos = find(features.begin(), features.end(), parent->getName());
     113      9524420 :                 if (pos == features.end()) {
     114            0 :                     throw std::logic_error("Feature parent " + parent->getName() + " not found in dataset");
     115              :                 }
     116      9524420 :                 int parent_index = pos - features.begin();
     117     28573260 :                 coordinates.push_back(dataset.index({ parent_index, n_sample }));
     118              :             }
     119              :             // Increment the count of the corresponding coordinate
     120     10245976 :             cpTable.index_put_({ coordinates }, cpTable.index({ coordinates }) + weights.index({ n_sample }).item<double>());
     121      5122988 :         }
     122              :         // Normalize the counts
     123        14194 :         cpTable = cpTable / cpTable.sum(0);
     124     19784590 :     }
     125     31802436 :     float Node::getFactorValue(std::map<std::string, int>& evidence)
     126              :     {
     127     31802436 :         c10::List<c10::optional<at::Tensor>> coordinates;
     128              :         // following predetermined order of indices in the cpTable (see Node.h)
     129     31802436 :         coordinates.push_back(at::tensor(evidence[name]));
     130     90628968 :         transform(parents.begin(), parents.end(), std::back_inserter(coordinates), [&evidence](const auto& parent) { return at::tensor(evidence[parent->getName()]); });
     131     63604872 :         return cpTable.index({ coordinates }).item<float>();
     132     31802436 :     }
     133          306 :     std::vector<std::string> Node::graph(const std::string& className)
     134              :     {
     135          306 :         auto output = std::vector<std::string>();
     136          306 :         auto suffix = name == className ? ", fontcolor=red, fillcolor=lightblue, style=filled " : "";
     137          306 :         output.push_back(name + " [shape=circle" + suffix + "] \n");
     138          788 :         transform(children.begin(), children.end(), back_inserter(output), [this](const auto& child) { return name + " -> " + child->getName(); });
     139          306 :         return output;
     140            0 :     }
     141              : }
        

Generated by: LCOV version 2.0-1

</html>