LCOV - code coverage report
Current view: top level - bayesnet/network - Node.cc (source / functions) Coverage Total Hit
Test: coverage.info Lines: 95.5 % 88 84
Test Date: 2024-04-30 13:59:18 Functions: 100.0 % 20 20

            Line data    Source code
       1              : // ***************************************************************
       2              : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
       3              : // SPDX-FileType: SOURCE
       4              : // SPDX-License-Identifier: MIT
       5              : // ***************************************************************
       6              : 
       7              : #include "Node.h"
       8              : 
       9              : namespace bayesnet {
      10              : 
      11        40176 :     Node::Node(const std::string& name)
      12        40176 :         : name(name), numStates(0), cpTable(torch::Tensor()), parents(std::vector<Node*>()), children(std::vector<Node*>())
      13              :     {
      14        40176 :     }
      15            6 :     void Node::clear()
      16              :     {
      17            6 :         parents.clear();
      18            6 :         children.clear();
      19            6 :         cpTable = torch::Tensor();
      20            6 :         dimensions.clear();
      21            6 :         numStates = 0;
      22            6 :     }
      23    205208016 :     std::string Node::getName() const
      24              :     {
      25    205208016 :         return name;
      26              :     }
      27        74892 :     void Node::addParent(Node* parent)
      28              :     {
      29        74892 :         parents.push_back(parent);
      30        74892 :     }
      31           18 :     void Node::removeParent(Node* parent)
      32              :     {
      33           18 :         parents.erase(std::remove(parents.begin(), parents.end(), parent), parents.end());
      34           18 :     }
      35           18 :     void Node::removeChild(Node* child)
      36              :     {
      37           18 :         children.erase(std::remove(children.begin(), children.end(), child), children.end());
      38           18 :     }
      39        74904 :     void Node::addChild(Node* child)
      40              :     {
      41        74904 :         children.push_back(child);
      42        74904 :     }
      43         7608 :     std::vector<Node*>& Node::getParents()
      44              :     {
      45         7608 :         return parents;
      46              :     }
      47       100284 :     std::vector<Node*>& Node::getChildren()
      48              :     {
      49       100284 :         return children;
      50              :     }
      51        81552 :     int Node::getNumStates() const
      52              :     {
      53        81552 :         return numStates;
      54              :     }
      55        42582 :     void Node::setNumStates(int numStates)
      56              :     {
      57        42582 :         this->numStates = numStates;
      58        42582 :     }
      59          630 :     torch::Tensor& Node::getCPT()
      60              :     {
      61          630 :         return cpTable;
      62              :     }
      63              :     /*
      64              :      The MinFill criterion is a heuristic for variable elimination.
      65              :      The variable that minimizes the number of edges that need to be added to the graph to make it triangulated.
      66              :      This is done by counting the number of edges that need to be added to the graph if the variable is eliminated.
      67              :      The variable with the minimum number of edges is chosen.
      68              :      Here this is done computing the length of the combinations of the node neighbors taken 2 by 2.
      69              :     */
      70           30 :     unsigned Node::minFill()
      71              :     {
      72           30 :         std::unordered_set<std::string> neighbors;
      73           78 :         for (auto child : children) {
      74           48 :             neighbors.emplace(child->getName());
      75              :         }
      76           72 :         for (auto parent : parents) {
      77           42 :             neighbors.emplace(parent->getName());
      78              :         }
      79           30 :         auto source = std::vector<std::string>(neighbors.begin(), neighbors.end());
      80           60 :         return combinations(source).size();
      81           30 :     }
      82           30 :     std::vector<std::pair<std::string, std::string>> Node::combinations(const std::vector<std::string>& source)
      83              :     {
      84           30 :         std::vector<std::pair<std::string, std::string>> result;
      85          120 :         for (int i = 0; i < source.size(); ++i) {
      86           90 :             std::string temp = source[i];
      87          186 :             for (int j = i + 1; j < source.size(); ++j) {
      88           96 :                 result.push_back({ temp, source[j] });
      89              :             }
      90           90 :         }
      91           30 :         return result;
      92            0 :     }
      93        42582 :     void Node::computeCPT(const torch::Tensor& dataset, const std::vector<std::string>& features, const double laplaceSmoothing, const torch::Tensor& weights)
      94              :     {
      95        42582 :         dimensions.clear();
      96              :         // Get dimensions of the CPT
      97        42582 :         dimensions.push_back(numStates);
      98       121662 :         transform(parents.begin(), parents.end(), back_inserter(dimensions), [](const auto& parent) { return parent->getNumStates(); });
      99              : 
     100              :         // Create a tensor of zeros with the dimensions of the CPT
     101        42582 :         cpTable = torch::zeros(dimensions, torch::kFloat) + laplaceSmoothing;
     102              :         // Fill table with counts
     103        42582 :         auto pos = find(features.begin(), features.end(), name);
     104        42582 :         if (pos == features.end()) {
     105            0 :             throw std::logic_error("Feature " + name + " not found in dataset");
     106              :         }
     107        42582 :         int name_index = pos - features.begin();
     108     15411546 :         for (int n_sample = 0; n_sample < dataset.size(1); ++n_sample) {
     109     15368964 :             c10::List<c10::optional<at::Tensor>> coordinates;
     110     46106892 :             coordinates.push_back(dataset.index({ name_index, n_sample }));
     111     43942224 :             for (auto parent : parents) {
     112     28573260 :                 pos = find(features.begin(), features.end(), parent->getName());
     113     28573260 :                 if (pos == features.end()) {
     114            0 :                     throw std::logic_error("Feature parent " + parent->getName() + " not found in dataset");
     115              :                 }
     116     28573260 :                 int parent_index = pos - features.begin();
     117     85719780 :                 coordinates.push_back(dataset.index({ parent_index, n_sample }));
     118              :             }
     119              :             // Increment the count of the corresponding coordinate
     120     30737928 :             cpTable.index_put_({ coordinates }, cpTable.index({ coordinates }) + weights.index({ n_sample }).item<double>());
     121     15368964 :         }
     122              :         // Normalize the counts
     123        42582 :         cpTable = cpTable / cpTable.sum(0);
     124     59353770 :     }
     125     95407308 :     float Node::getFactorValue(std::map<std::string, int>& evidence)
     126              :     {
     127     95407308 :         c10::List<c10::optional<at::Tensor>> coordinates;
     128              :         // following predetermined order of indices in the cpTable (see Node.h)
     129     95407308 :         coordinates.push_back(at::tensor(evidence[name]));
     130    271886904 :         transform(parents.begin(), parents.end(), std::back_inserter(coordinates), [&evidence](const auto& parent) { return at::tensor(evidence[parent->getName()]); });
     131    190814616 :         return cpTable.index({ coordinates }).item<float>();
     132     95407308 :     }
     133          918 :     std::vector<std::string> Node::graph(const std::string& className)
     134              :     {
     135          918 :         auto output = std::vector<std::string>();
     136          918 :         auto suffix = name == className ? ", fontcolor=red, fillcolor=lightblue, style=filled " : "";
     137          918 :         output.push_back(name + " [shape=circle" + suffix + "] \n");
     138         2364 :         transform(children.begin(), children.end(), back_inserter(output), [this](const auto& child) { return name + " -> " + child->getName(); });
     139          918 :         return output;
     140            0 :     }
     141              : }
        

Generated by: LCOV version 2.0-1