BayesNet/bayesnet/network/Node.cc

143 lines
5.5 KiB
C++
Raw Permalink Normal View History

2024-04-11 16:02:49 +00:00
// ***************************************************************
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
// SPDX-FileType: SOURCE
// SPDX-License-Identifier: MIT
// ***************************************************************
2023-06-29 20:00:41 +00:00
#include "Node.h"
namespace bayesnet {
2023-08-05 12:40:42 +00:00
Node::Node(const std::string& name)
: name(name)
2023-06-29 20:00:41 +00:00
{
}
2023-07-25 23:39:01 +00:00
void Node::clear()
{
parents.clear();
children.clear();
cpTable = torch::Tensor();
dimensions.clear();
numStates = 0;
}
2023-11-08 17:45:35 +00:00
std::string Node::getName() const
2023-06-29 20:00:41 +00:00
{
return name;
}
/**
 * @brief Append `parent` to this node's parent list.
 *
 * No duplicate check is done, and the parent's child list is not updated.
 */
void Node::addParent(Node* parent)
{
    parents.emplace_back(parent);
}
2023-06-29 21:53:33 +00:00
/**
 * @brief Remove every occurrence of `parent` from this node's parent list.
 *
 * Uses the erase-remove idiom. It does nothing if `parent` is not present. The
 * parent's own child list is left unchanged.
 */
void Node::removeParent(Node* parent)
{
    auto keptEnd = std::remove(parents.begin(), parents.end(), parent);
    parents.erase(keptEnd, parents.end());
}
/**
 * @brief Remove every occurrence of `child` from this node's child list.
 *
 * Uses the erase-remove idiom. It does nothing if `child` is not present. The
 * child's own parent list is left unchanged.
 */
void Node::removeChild(Node* child)
{
    auto keptEnd = std::remove(children.begin(), children.end(), child);
    children.erase(keptEnd, children.end());
}
2023-06-29 20:00:41 +00:00
/**
 * @brief Append `child` to this node's child list.
 *
 * No duplicate check is done, and the child's parent list is not updated.
 */
void Node::addChild(Node* child)
{
    children.emplace_back(child);
}
2023-11-08 17:45:35 +00:00
std::vector<Node*>& Node::getParents()
2023-06-29 20:00:41 +00:00
{
return parents;
}
2023-11-08 17:45:35 +00:00
std::vector<Node*>& Node::getChildren()
2023-06-29 20:00:41 +00:00
{
return children;
}
int Node::getNumStates() const
{
return numStates;
}
/// @brief Set the number of states (cardinality) of this node.
void Node::setNumStates(int numStates)
{
    // The parameter shadows the member, so the member is named through its class.
    Node::numStates = numStates;
}
2023-06-30 19:24:12 +00:00
/**
 * @brief Mutable reference to the conditional probability table.
 *
 * The tensor is undefined until computeCPT() has run, and again after clear().
 */
torch::Tensor& Node::getCPT()
{
    return this->cpTable;
}
/*
The MinFill criterion is a heuristic for variable elimination.
The variable that minimizes the number of edges that need to be added to the graph to make it triangulated.
This is done by counting the number of edges that need to be added to the graph if the variable is eliminated.
The variable with the minimum number of edges is chosen.
Here this is done computing the length of the combinations of the node neighbors taken 2 by 2.
*/
unsigned Node::minFill()
{
2023-11-08 17:45:35 +00:00
std::unordered_set<std::string> neighbors;
for (auto child : children) {
neighbors.emplace(child->getName());
}
for (auto parent : parents) {
neighbors.emplace(parent->getName());
}
2023-11-08 17:45:35 +00:00
auto source = std::vector<std::string>(neighbors.begin(), neighbors.end());
return combinations(source).size();
}
2023-11-08 17:45:35 +00:00
/**
 * @brief All unordered pairs of distinct positions in `source`.
 *
 * Element i is paired with every element after it, in input order. For example,
 * {a, b, c} yields {(a,b), (a,c), (b,c)}. Inputs with fewer than two elements yield
 * an empty vector.
 *
 * @param source Items to pair.
 * @return n*(n-1)/2 pairs for n = source.size().
 */
std::vector<std::pair<std::string, std::string>> Node::combinations(const std::vector<std::string>& source)
{
    std::vector<std::pair<std::string, std::string>> result;
    const std::size_t n = source.size();
    if (n < 2) {
        return result;
    }
    result.reserve(n * (n - 1) / 2);  // exact final size: one allocation
    // size_t indices avoid the signed/unsigned comparison with size().
    for (std::size_t i = 0; i + 1 < n; ++i) {
        for (std::size_t j = i + 1; j < n; ++j) {
            result.emplace_back(source[i], source[j]);
        }
    }
    return result;
}
2024-06-09 15:19:38 +00:00
/**
 * @brief Estimate this node's conditional probability table (CPT) from weighted samples.
 *
 * The CPT shape is [numStates, parents[0].numStates, parents[1].numStates, ...]: this
 * node's own state first, then each parent in the order of `parents`.
 * - Every cell starts at `smoothing` (an additive pseudo-count).
 * - Each sample adds its weight to the cell addressed by its own value and its parents'
 *   values.
 * - The counts are then normalized over dimension 0, so for every parent configuration
 *   the entries over this node's states sum to 1.
 *
 * @param dataset   Samples, indexed as dataset[..., sample] with the feature selected by
 *                  the remaining leading index. Presumably a 2-D (features x samples)
 *                  tensor of integer state codes — TODO confirm against callers.
 * @param features  Feature names, in the same order as the dataset's feature index.
 * @param smoothing Pseudo-count added to every CPT cell before counting.
 * @param weights   Per-sample weights; weights[n] is accumulated for sample n.
 * @throws std::logic_error if this node or any of its parents is not in `features`.
 *         A missing parent is now reported before any sample is processed, even when
 *         the dataset has no samples.
 */
void Node::computeCPT(const torch::Tensor& dataset, const std::vector<std::string>& features, const double smoothing, const torch::Tensor& weights)
{
    // Position of `featureName` in `features`; throws "<prefix><name> not found in dataset".
    auto featureIndex = [&features](const std::string& featureName, const std::string& prefix) {
        auto pos = std::find(features.begin(), features.end(), featureName);
        if (pos == features.end()) {
            throw std::logic_error(prefix + featureName + " not found in dataset");
        }
        return static_cast<int>(std::distance(features.begin(), pos));
    };
    // CPT shape: own cardinality first, then each parent's cardinality.
    dimensions.clear();
    dimensions.push_back(numStates);
    std::transform(parents.begin(), parents.end(), std::back_inserter(dimensions), [](const auto& parent) { return parent->getNumStates(); });
    // Every cell starts with the smoothing pseudo-count.
    cpTable = torch::zeros(dimensions, torch::kDouble).to(device) + smoothing;
    // Resolve the feature positions once. The original looked parents up again for every
    // sample, costing O(samples * parents * features).
    const int name_index = featureIndex(name, "Feature ");
    std::vector<int> parent_indices;
    parent_indices.reserve(parents.size());
    for (const auto& parent : parents) {
        parent_indices.push_back(featureIndex(parent->getName(), "Feature parent "));
    }
    // Accumulate each sample's weight into the cell (own value, parent values...).
    c10::List<c10::optional<at::Tensor>> coordinates;
    const int64_t n_samples = dataset.size(1);
    for (int64_t n_sample = 0; n_sample < n_samples; ++n_sample) {
        coordinates.clear();
        auto sample = dataset.index({ "...", n_sample });
        coordinates.push_back(sample[name_index]);
        for (const int parent_index : parent_indices) {
            coordinates.push_back(sample[parent_index]);
        }
        // accumulate=true: add to the existing cell value instead of overwriting it.
        cpTable.index_put_({ coordinates }, weights.index({ n_sample }), true);
    }
    // Normalize over dimension 0 (this node's states). cpTable.sum(0) holds one total per
    // parent configuration and broadcasts back over dim 0, yielding P(node | parents).
    // NOTE(review): with smoothing == 0, a parent configuration that never occurs in the
    // data has a zero total, so its entries become NaN (0/0).
    cpTable = cpTable / cpTable.sum(0);
}
2024-07-08 11:27:55 +00:00
double Node::getFactorValue(std::map<std::string, int>& evidence)
2023-07-05 16:38:54 +00:00
{
2023-09-02 11:58:12 +00:00
c10::List<c10::optional<at::Tensor>> coordinates;
2023-07-05 16:38:54 +00:00
// following predetermined order of indices in the cpTable (see Node.h)
2023-09-02 11:58:12 +00:00
coordinates.push_back(at::tensor(evidence[name]));
2023-11-08 17:45:35 +00:00
transform(parents.begin(), parents.end(), std::back_inserter(coordinates), [&evidence](const auto& parent) { return at::tensor(evidence[parent->getName()]); });
2024-07-08 11:27:55 +00:00
return cpTable.index({ coordinates }).item<double>();
}
2023-11-08 17:45:35 +00:00
std::vector<std::string> Node::graph(const std::string& className)
2023-07-15 23:20:47 +00:00
{
2023-11-08 17:45:35 +00:00
auto output = std::vector<std::string>();
2023-07-15 23:20:47 +00:00
auto suffix = name == className ? ", fontcolor=red, fillcolor=lightblue, style=filled " : "";
output.push_back("\"" + name + "\" [shape=circle" + suffix + "] \n");
transform(children.begin(), children.end(), back_inserter(output), [this](const auto& child) { return "\"" + name + "\" -> \"" + child->getName() + "\""; });
2023-07-15 23:20:47 +00:00
return output;
}
2023-06-29 20:00:41 +00:00
}