Files
BayesNet/html/bayesnet/network/Node.cc.gcov.html
2024-05-06 17:56:00 +02:00

22 KiB

<html lang="en"> <head> </head>
LCOV - code coverage report
Current view: top level - bayesnet/network - Node.cc (source / functions) Coverage Total Hit
Test: BayesNet Coverage Report Lines: 100.0 % 88 88
Test Date: 2024-05-06 17:54:04 Functions: 100.0 % 20 20
Legend: Lines: hit not hit

            Line data    Source code
       1              : // ***************************************************************
       2              : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
       3              : // SPDX-FileType: SOURCE
       4              : // SPDX-License-Identifier: MIT
       5              : // ***************************************************************
       6              : 
       7              : #include "Node.h"
       8              : 
       9              : namespace bayesnet {
      10              : 
      11        31339 :     Node::Node(const std::string& name)
      12        31339 :         : name(name)
      13              :     {
      14        31339 :     }
      15            9 :     void Node::clear()
      16              :     {
      17            9 :         parents.clear();
      18            9 :         children.clear();
      19            9 :         cpTable = torch::Tensor();
      20            9 :         dimensions.clear();
      21            9 :         numStates = 0;
      22            9 :     }
      23    150429643 :     std::string Node::getName() const
      24              :     {
      25    150429643 :         return name;
      26              :     }
      27        59262 :     void Node::addParent(Node* parent)
      28              :     {
      29        59262 :         parents.push_back(parent);
      30        59262 :     }
      31           17 :     void Node::removeParent(Node* parent)
      32              :     {
      33           17 :         parents.erase(std::remove(parents.begin(), parents.end(), parent), parents.end());
      34           17 :     }
      35           17 :     void Node::removeChild(Node* child)
      36              :     {
      37           17 :         children.erase(std::remove(children.begin(), children.end(), child), children.end());
      38           17 :     }
      39        59235 :     void Node::addChild(Node* child)
      40              :     {
      41        59235 :         children.push_back(child);
      42        59235 :     }
      43         5087 :     std::vector<Node*>& Node::getParents()
      44              :     {
      45         5087 :         return parents;
      46              :     }
      47        77571 :     std::vector<Node*>& Node::getChildren()
      48              :     {
      49        77571 :         return children;
      50              :     }
      51        64124 :     int Node::getNumStates() const
      52              :     {
      53        64124 :         return numStates;
      54              :     }
      55        32864 :     void Node::setNumStates(int numStates)
      56              :     {
      57        32864 :         this->numStates = numStates;
      58        32864 :     }
      59          429 :     torch::Tensor& Node::getCPT()
      60              :     {
      61          429 :         return cpTable;
      62              :     }
      63              :     /*
      64              :      The MinFill criterion is a heuristic for variable elimination.
      65              :      The variable that minimizes the number of edges that need to be added to the graph to make it triangulated.
      66              :      This is done by counting the number of edges that need to be added to the graph if the variable is eliminated.
      67              :      The variable with the minimum number of edges is chosen.
      68              :      Here this is done computing the length of the combinations of the node neighbors taken 2 by 2.
      69              :     */
      70           45 :     unsigned Node::minFill()
      71              :     {
      72           45 :         std::unordered_set<std::string> neighbors;
      73          117 :         for (auto child : children) {
      74           72 :             neighbors.emplace(child->getName());
      75              :         }
      76          108 :         for (auto parent : parents) {
      77           63 :             neighbors.emplace(parent->getName());
      78              :         }
      79           45 :         auto source = std::vector<std::string>(neighbors.begin(), neighbors.end());
      80           90 :         return combinations(source).size();
      81           45 :     }
      82           45 :     std::vector<std::pair<std::string, std::string>> Node::combinations(const std::vector<std::string>& source)
      83              :     {
      84           45 :         std::vector<std::pair<std::string, std::string>> result;
      85          180 :         for (int i = 0; i < source.size(); ++i) {
      86          135 :             std::string temp = source[i];
      87          279 :             for (int j = i + 1; j < source.size(); ++j) {
      88          144 :                 result.push_back({ temp, source[j] });
      89              :             }
      90          135 :         }
      91           90 :         return result;
      92           45 :     }
      93        32894 :     void Node::computeCPT(const torch::Tensor& dataset, const std::vector<std::string>& features, const double laplaceSmoothing, const torch::Tensor& weights)
      94              :     {
      95        32894 :         dimensions.clear();
      96              :         // Get dimensions of the CPT
      97        32894 :         dimensions.push_back(numStates);
      98        94914 :         transform(parents.begin(), parents.end(), back_inserter(dimensions), [](const auto& parent) { return parent->getNumStates(); });
      99              :         // Create a tensor of zeros with the dimensions of the CPT
     100        32894 :         cpTable = torch::zeros(dimensions, torch::kFloat) + laplaceSmoothing;
     101              :         // Fill table with counts
     102        32894 :         auto pos = find(features.begin(), features.end(), name);
     103        32894 :         if (pos == features.end()) {
     104            8 :             throw std::logic_error("Feature " + name + " not found in dataset");
     105              :         }
     106        32886 :         int name_index = pos - features.begin();
     107     11221522 :         for (int n_sample = 0; n_sample < dataset.size(1); ++n_sample) {
     108     11188649 :             c10::List<c10::optional<at::Tensor>> coordinates;
     109     33565947 :             coordinates.push_back(dataset.index({ name_index, n_sample }));
     110     32200749 :             for (auto parent : parents) {
     111     21012113 :                 pos = find(features.begin(), features.end(), parent->getName());
     112     21012113 :                 if (pos == features.end()) {
     113           13 :                     throw std::logic_error("Feature parent " + parent->getName() + " not found in dataset");
     114              :                 }
     115     21012100 :                 int parent_index = pos - features.begin();
     116     63036300 :                 coordinates.push_back(dataset.index({ parent_index, n_sample }));
     117              :             }
     118              :             // Increment the count of the corresponding coordinate
     119     22377272 :             cpTable.index_put_({ coordinates }, cpTable.index({ coordinates }) + weights.index({ n_sample }).item<double>());
     120     11188649 :         }
     121              :         // Normalize the counts
     122        32873 :         cpTable = cpTable / cpTable.sum(0);
     123     43422258 :     }
     124     69151761 :     float Node::getFactorValue(std::map<std::string, int>& evidence)
     125              :     {
     126     69151761 :         c10::List<c10::optional<at::Tensor>> coordinates;
     127              :         // following predetermined order of indices in the cpTable (see Node.h)
     128     69151761 :         coordinates.push_back(at::tensor(evidence[name]));
     129    198453273 :         transform(parents.begin(), parents.end(), std::back_inserter(coordinates), [&evidence](const auto& parent) { return at::tensor(evidence[parent->getName()]); });
     130    138303522 :         return cpTable.index({ coordinates }).item<float>();
     131     69151761 :     }
     132          732 :     std::vector<std::string> Node::graph(const std::string& className)
     133              :     {
     134          732 :         auto output = std::vector<std::string>();
     135          732 :         auto suffix = name == className ? ", fontcolor=red, fillcolor=lightblue, style=filled " : "";
     136          732 :         output.push_back(name + " [shape=circle" + suffix + "] \n");
     137         1840 :         transform(children.begin(), children.end(), back_inserter(output), [this](const auto& child) { return name + " -> " + child->getName(); });
     138         1464 :         return output;
     139          732 :     }
     140              : }
        

Generated by: LCOV version 2.0-1

</html>