22 KiB
22 KiB
<html lang="en">
<head>
</head>
</html>
LCOV - code coverage report | ||||||||||||||||||||||
![]() | ||||||||||||||||||||||
|
||||||||||||||||||||||
![]() |
Line data Source code 1 : // *************************************************************** 2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez 3 : // SPDX-FileType: SOURCE 4 : // SPDX-License-Identifier: MIT 5 : // *************************************************************** 6 : 7 : #include "Node.h" 8 : 9 : namespace bayesnet { 10 : 11 40176 : Node::Node(const std::string& name) 12 40176 : : name(name), numStates(0), cpTable(torch::Tensor()), parents(std::vector<Node*>()), children(std::vector<Node*>()) 13 : { 14 40176 : } 15 6 : void Node::clear() 16 : { 17 6 : parents.clear(); 18 6 : children.clear(); 19 6 : cpTable = torch::Tensor(); 20 6 : dimensions.clear(); 21 6 : numStates = 0; 22 6 : } 23 205208016 : std::string Node::getName() const 24 : { 25 205208016 : return name; 26 : } 27 74892 : void Node::addParent(Node* parent) 28 : { 29 74892 : parents.push_back(parent); 30 74892 : } 31 18 : void Node::removeParent(Node* parent) 32 : { 33 18 : parents.erase(std::remove(parents.begin(), parents.end(), parent), parents.end()); 34 18 : } 35 18 : void Node::removeChild(Node* child) 36 : { 37 18 : children.erase(std::remove(children.begin(), children.end(), child), children.end()); 38 18 : } 39 74904 : void Node::addChild(Node* child) 40 : { 41 74904 : children.push_back(child); 42 74904 : } 43 7608 : std::vector<Node*>& Node::getParents() 44 : { 45 7608 : return parents; 46 : } 47 100284 : std::vector<Node*>& Node::getChildren() 48 : { 49 100284 : return children; 50 : } 51 81552 : int Node::getNumStates() const 52 : { 53 81552 : return numStates; 54 : } 55 42582 : void Node::setNumStates(int numStates) 56 : { 57 42582 : this->numStates = numStates; 58 42582 : } 59 630 : torch::Tensor& Node::getCPT() 60 : { 61 630 : return cpTable; 62 : } 63 : /* 64 : The MinFill criterion is a heuristic for variable elimination. 65 : The variable that minimizes the number of edges that need to be added to the graph to make it triangulated. 66 : This is done by counting the number of edges that need to be added to the graph if the variable is eliminated. 67 : The variable with the minimum number of edges is chosen. 68 : Here this is done computing the length of the combinations of the node neighbors taken 2 by 2. 69 : */ 70 30 : unsigned Node::minFill() 71 : { 72 30 : std::unordered_set<std::string> neighbors; 73 78 : for (auto child : children) { 74 48 : neighbors.emplace(child->getName()); 75 : } 76 72 : for (auto parent : parents) { 77 42 : neighbors.emplace(parent->getName()); 78 : } 79 30 : auto source = std::vector<std::string>(neighbors.begin(), neighbors.end()); 80 60 : return combinations(source).size(); 81 30 : } 82 30 : std::vector<std::pair<std::string, std::string>> Node::combinations(const std::vector<std::string>& source) 83 : { 84 30 : std::vector<std::pair<std::string, std::string>> result; 85 120 : for (int i = 0; i < source.size(); ++i) { 86 90 : std::string temp = source[i]; 87 186 : for (int j = i + 1; j < source.size(); ++j) { 88 96 : result.push_back({ temp, source[j] }); 89 : } 90 90 : } 91 30 : return result; 92 0 : } 93 42582 : void Node::computeCPT(const torch::Tensor& dataset, const std::vector<std::string>& features, const double laplaceSmoothing, const torch::Tensor& weights) 94 : { 95 42582 : dimensions.clear(); 96 : // Get dimensions of the CPT 97 42582 : dimensions.push_back(numStates); 98 121662 : transform(parents.begin(), parents.end(), back_inserter(dimensions), [](const auto& parent) { return parent->getNumStates(); }); 99 : 100 : // Create a tensor of zeros with the dimensions of the CPT 101 42582 : cpTable = torch::zeros(dimensions, torch::kFloat) + laplaceSmoothing; 102 : // Fill table with counts 103 42582 : auto pos = find(features.begin(), features.end(), name); 104 42582 : if (pos == features.end()) { 105 0 : throw std::logic_error("Feature " + name + " not found in dataset"); 106 : } 107 42582 : int name_index = pos - features.begin(); 108 15411546 : for (int n_sample = 0; n_sample < dataset.size(1); ++n_sample) { 109 15368964 : c10::List<c10::optional<at::Tensor>> coordinates; 110 46106892 : coordinates.push_back(dataset.index({ name_index, n_sample })); 111 43942224 : for (auto parent : parents) { 112 28573260 : pos = find(features.begin(), features.end(), parent->getName()); 113 28573260 : if (pos == features.end()) { 114 0 : throw std::logic_error("Feature parent " + parent->getName() + " not found in dataset"); 115 : } 116 28573260 : int parent_index = pos - features.begin(); 117 85719780 : coordinates.push_back(dataset.index({ parent_index, n_sample })); 118 : } 119 : // Increment the count of the corresponding coordinate 120 30737928 : cpTable.index_put_({ coordinates }, cpTable.index({ coordinates }) + weights.index({ n_sample }).item<double>()); 121 15368964 : } 122 : // Normalize the counts 123 42582 : cpTable = cpTable / cpTable.sum(0); 124 59353770 : } 125 95407308 : float Node::getFactorValue(std::map<std::string, int>& evidence) 126 : { 127 95407308 : c10::List<c10::optional<at::Tensor>> coordinates; 128 : // following predetermined order of indices in the cpTable (see Node.h) 129 95407308 : coordinates.push_back(at::tensor(evidence[name])); 130 271886904 : transform(parents.begin(), parents.end(), std::back_inserter(coordinates), [&evidence](const auto& parent) { return at::tensor(evidence[parent->getName()]); }); 131 190814616 : return cpTable.index({ coordinates }).item<float>(); 132 95407308 : } 133 918 : std::vector<std::string> Node::graph(const std::string& className) 134 : { 135 918 : auto output = std::vector<std::string>(); 136 918 : auto suffix = name == className ? ", fontcolor=red, fillcolor=lightblue, style=filled " : ""; 137 918 : output.push_back(name + " [shape=circle" + suffix + "] \n"); 138 2364 : transform(children.begin(), children.end(), back_inserter(output), [this](const auto& child) { return name + " -> " + child->getName(); }); 139 918 : return output; 140 0 : } 141 : } |
![]() |
Generated by: LCOV version 2.0-1 |
</html>