LCOV - code coverage report
Current view: top level - bayesnet/network - Node.cc (source / functions) Coverage Total Hit
Test: coverage.info Lines: 95.5 % 88 84
Test Date: 2024-04-29 20:48:03 Functions: 100.0 % 20 20

            Line data    Source code
       1              : // ***************************************************************
       2              : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
       3              : // SPDX-FileType: SOURCE
       4              : // SPDX-License-Identifier: MIT
       5              : // ***************************************************************
       6              : 
       7              : #include "Node.h"
       8              : 
       9              : namespace bayesnet {
      10              : 
      11       116977 :     Node::Node(const std::string& name)
      12       116977 :         : name(name), numStates(0), cpTable(torch::Tensor()), parents(std::vector<Node*>()), children(std::vector<Node*>())
      13              :     {
      14       116977 :     }
      15           11 :     void Node::clear()
      16              :     {
      17           11 :         parents.clear();
      18           11 :         children.clear();
      19           11 :         cpTable = torch::Tensor();
      20           11 :         dimensions.clear();
      21           11 :         numStates = 0;
      22           11 :     }
      23    159096442 :     std::string Node::getName() const
      24              :     {
      25    159096442 :         return name;
      26              :     }
      27       224331 :     void Node::addParent(Node* parent)
      28              :     {
      29       224331 :         parents.push_back(parent);
      30       224331 :     }
      31           33 :     void Node::removeParent(Node* parent)
      32              :     {
      33           33 :         parents.erase(std::remove(parents.begin(), parents.end(), parent), parents.end());
      34           33 :     }
      35           33 :     void Node::removeChild(Node* child)
      36              :     {
      37           33 :         children.erase(std::remove(children.begin(), children.end(), child), children.end());
      38           33 :     }
      39       224353 :     void Node::addChild(Node* child)
      40              :     {
      41       224353 :         children.push_back(child);
      42       224353 :     }
      43        13948 :     std::vector<Node*>& Node::getParents()
      44              :     {
      45        13948 :         return parents;
      46              :     }
      47       306501 :     std::vector<Node*>& Node::getChildren()
      48              :     {
      49       306501 :         return children;
      50              :     }
      51       236412 :     int Node::getNumStates() const
      52              :     {
      53       236412 :         return numStates;
      54              :     }
      55       121388 :     void Node::setNumStates(int numStates)
      56              :     {
      57       121388 :         this->numStates = numStates;
      58       121388 :     }
      59         1155 :     torch::Tensor& Node::getCPT()
      60              :     {
      61         1155 :         return cpTable;
      62              :     }
      63              :     /*
      64              :      The MinFill criterion is a heuristic for variable elimination.
      65              :      The variable that minimizes the number of edges that need to be added to the graph to make it triangulated.
      66              :      This is done by counting the number of edges that need to be added to the graph if the variable is eliminated.
      67              :      The variable with the minimum number of edges is chosen.
      68              :      Here this is done computing the length of the combinations of the node neighbors taken 2 by 2.
      69              :     */
      70           55 :     unsigned Node::minFill()
      71              :     {
      72           55 :         std::unordered_set<std::string> neighbors;
      73          143 :         for (auto child : children) {
      74           88 :             neighbors.emplace(child->getName());
      75              :         }
      76          132 :         for (auto parent : parents) {
      77           77 :             neighbors.emplace(parent->getName());
      78              :         }
      79           55 :         auto source = std::vector<std::string>(neighbors.begin(), neighbors.end());
      80          110 :         return combinations(source).size();
      81           55 :     }
      82           55 :     std::vector<std::pair<std::string, std::string>> Node::combinations(const std::vector<std::string>& source)
      83              :     {
      84           55 :         std::vector<std::pair<std::string, std::string>> result;
      85          220 :         for (int i = 0; i < source.size(); ++i) {
      86          165 :             std::string temp = source[i];
      87          341 :             for (int j = i + 1; j < source.size(); ++j) {
      88          176 :                 result.push_back({ temp, source[j] });
      89              :             }
      90          165 :         }
      91           55 :         return result;
      92            0 :     }
      93       121388 :     void Node::computeCPT(const torch::Tensor& dataset, const std::vector<std::string>& features, const double laplaceSmoothing, const torch::Tensor& weights)
      94              :     {
      95       121388 :         dimensions.clear();
      96              :         // Get dimensions of the CPT
      97       121388 :         dimensions.push_back(numStates);
      98       353397 :         transform(parents.begin(), parents.end(), back_inserter(dimensions), [](const auto& parent) { return parent->getNumStates(); });
      99              : 
     100              :         // Create a tensor of zeros with the dimensions of the CPT
     101       121388 :         cpTable = torch::zeros(dimensions, torch::kFloat) + laplaceSmoothing;
     102              :         // Fill table with counts
     103       121388 :         auto pos = find(features.begin(), features.end(), name);
     104       121388 :         if (pos == features.end()) {
     105            0 :             throw std::logic_error("Feature " + name + " not found in dataset");
     106              :         }
     107       121388 :         int name_index = pos - features.begin();
     108     21284350 :         for (int n_sample = 0; n_sample < dataset.size(1); ++n_sample) {
     109     21162962 :             c10::List<c10::optional<at::Tensor>> coordinates;
     110     63488886 :             coordinates.push_back(dataset.index({ name_index, n_sample }));
     111     60665104 :             for (auto parent : parents) {
     112     39502142 :                 pos = find(features.begin(), features.end(), parent->getName());
     113     39502142 :                 if (pos == features.end()) {
     114            0 :                     throw std::logic_error("Feature parent " + parent->getName() + " not found in dataset");
     115              :                 }
     116     39502142 :                 int parent_index = pos - features.begin();
     117    118506426 :                 coordinates.push_back(dataset.index({ parent_index, n_sample }));
     118              :             }
     119              :             // Increment the count of the corresponding coordinate
     120     42325924 :             cpTable.index_put_({ coordinates }, cpTable.index({ coordinates }) + weights.index({ n_sample }).item<double>());
     121     21162962 :         }
     122              :         // Normalize the counts
     123       121388 :         cpTable = cpTable / cpTable.sum(0);
     124     81949454 :     }
     125     67302838 :     float Node::getFactorValue(std::map<std::string, int>& evidence)
     126              :     {
     127     67302838 :         c10::List<c10::optional<at::Tensor>> coordinates;
     128              :         // following predetermined order of indices in the cpTable (see Node.h)
     129     67302838 :         coordinates.push_back(at::tensor(evidence[name]));
     130    186413412 :         transform(parents.begin(), parents.end(), std::back_inserter(coordinates), [&evidence](const auto& parent) { return at::tensor(evidence[parent->getName()]); });
     131    134605676 :         return cpTable.index({ coordinates }).item<float>();
     132     67302838 :     }
     133         1683 :     std::vector<std::string> Node::graph(const std::string& className)
     134              :     {
     135         1683 :         auto output = std::vector<std::string>();
     136         1683 :         auto suffix = name == className ? ", fontcolor=red, fillcolor=lightblue, style=filled " : "";
     137         1683 :         output.push_back(name + " [shape=circle" + suffix + "] \n");
     138         4334 :         transform(children.begin(), children.end(), back_inserter(output), [this](const auto& child) { return name + " -> " + child->getName(); });
     139         1683 :         return output;
     140            0 :     }
     141              : }
        

Generated by: LCOV version 2.0-1