2024-04-11 16:02:49 +00:00
|
|
|
// ***************************************************************
|
|
|
|
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
|
|
|
|
// SPDX-FileType: SOURCE
|
|
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
// ***************************************************************
|
|
|
|
|
2023-06-29 20:00:41 +00:00
|
|
|
#include "Node.h"
|
|
|
|
|
|
|
|
namespace bayesnet {
|
|
|
|
|
2023-08-05 12:40:42 +00:00
|
|
|
Node::Node(const std::string& name)
|
2023-11-08 17:45:35 +00:00
|
|
|
: name(name), numStates(0), cpTable(torch::Tensor()), parents(std::vector<Node*>()), children(std::vector<Node*>())
|
2023-06-29 20:00:41 +00:00
|
|
|
{
|
|
|
|
}
|
2023-07-25 23:39:01 +00:00
|
|
|
void Node::clear()
|
|
|
|
{
|
|
|
|
parents.clear();
|
|
|
|
children.clear();
|
|
|
|
cpTable = torch::Tensor();
|
|
|
|
dimensions.clear();
|
|
|
|
numStates = 0;
|
|
|
|
}
|
2023-11-08 17:45:35 +00:00
|
|
|
std::string Node::getName() const
|
2023-06-29 20:00:41 +00:00
|
|
|
{
|
|
|
|
return name;
|
|
|
|
}
|
|
|
|
// Append `parent` to this node's parent list.
// No duplicate or null check is done, and the reverse (child) link is not created here.
void Node::addParent(Node* parent)
{
    parents.emplace_back(parent);
}
|
2023-06-29 21:53:33 +00:00
|
|
|
// Remove every occurrence of `parent` from the parent list (erase-remove idiom).
// Has no effect when `parent` is not present.
void Node::removeParent(Node* parent)
{
    auto firstRemoved = std::remove(parents.begin(), parents.end(), parent);
    parents.erase(firstRemoved, parents.end());
}
|
|
|
|
// Remove every occurrence of `child` from the children list (erase-remove idiom).
// Has no effect when `child` is not present.
void Node::removeChild(Node* child)
{
    auto firstRemoved = std::remove(children.begin(), children.end(), child);
    children.erase(firstRemoved, children.end());
}
|
2023-06-29 20:00:41 +00:00
|
|
|
// Append `child` to this node's children list.
// No duplicate or null check is done, and the reverse (parent) link is not created here.
void Node::addChild(Node* child)
{
    children.emplace_back(child);
}
|
2023-11-08 17:45:35 +00:00
|
|
|
std::vector<Node*>& Node::getParents()
|
2023-06-29 20:00:41 +00:00
|
|
|
{
|
|
|
|
return parents;
|
|
|
|
}
|
2023-11-08 17:45:35 +00:00
|
|
|
std::vector<Node*>& Node::getChildren()
|
2023-06-29 20:00:41 +00:00
|
|
|
{
|
|
|
|
return children;
|
|
|
|
}
|
|
|
|
int Node::getNumStates() const
|
|
|
|
{
|
|
|
|
return numStates;
|
|
|
|
}
|
2023-07-01 12:45:44 +00:00
|
|
|
// Record how many states this variable can take.
// computeCPT uses this as the size of the CPT's first dimension.
void Node::setNumStates(int numStates)
{
    Node::numStates = numStates;
}
|
2023-06-30 19:24:12 +00:00
|
|
|
// Give mutable access to the conditional probability table tensor.
torch::Tensor& Node::getCPT()
{
    return this->cpTable;
}
|
2023-07-02 18:39:13 +00:00
|
|
|
/*
 The MinFill criterion is a heuristic for variable elimination.
 It selects the variable whose elimination requires adding the fewest edges to keep the graph triangulated,
 i.e. the fewest "fill-in" edges needed to connect all of the variable's neighbors to each other.
 The variable with the minimum number of such edges is chosen.
 Here this is approximated by the number of 2-element combinations of the node's neighbors (parents and children).
 Pairs that are already connected are not discounted, so the value is an upper bound on the true fill-in.
*/
|
|
|
|
unsigned Node::minFill()
|
|
|
|
{
|
2023-11-08 17:45:35 +00:00
|
|
|
std::unordered_set<std::string> neighbors;
|
2023-07-02 18:39:13 +00:00
|
|
|
for (auto child : children) {
|
|
|
|
neighbors.emplace(child->getName());
|
|
|
|
}
|
|
|
|
for (auto parent : parents) {
|
|
|
|
neighbors.emplace(parent->getName());
|
|
|
|
}
|
2023-11-08 17:45:35 +00:00
|
|
|
auto source = std::vector<std::string>(neighbors.begin(), neighbors.end());
|
2023-07-11 15:42:20 +00:00
|
|
|
return combinations(source).size();
|
2023-07-02 18:39:13 +00:00
|
|
|
}
|
2023-11-08 17:45:35 +00:00
|
|
|
// All 2-element combinations of `source`.
// Returns every pair {source[i], source[j]} with i < j, ordered by i and then j.
// The result is empty when `source` has fewer than two elements.
std::vector<std::pair<std::string, std::string>> Node::combinations(const std::vector<std::string>& source)
{
    std::vector<std::pair<std::string, std::string>> result;
    const std::size_t n = source.size();
    if (n < 2) {
        return result;
    }
    // Exact final size is known up front: n choose 2.
    result.reserve(n * (n - 1) / 2);
    for (std::size_t i = 0; i + 1 < n; ++i) {
        const std::string& first = source[i];
        for (std::size_t j = i + 1; j < n; ++j) {
            result.emplace_back(first, source[j]);
        }
    }
    return result;
}
|
2023-11-08 17:45:35 +00:00
|
|
|
// Estimate this node's conditional probability table (CPT) from data.
//
// The CPT has one dimension for this node's states, followed by one per parent (in parent order).
// Each cell starts at `laplaceSmoothing`. For every sample, `weights[sample]` is added to the cell
// addressed by the sample's values. The table is then normalized over the first dimension,
// so each parent configuration sums to 1.
//
// dataset:  indexed as dataset[feature_row, sample]. Row i holds the values of features[i];
//           the values are used directly as CPT coordinates (state indices).
// features: row names of `dataset`; must contain this node's name and every parent's name.
// weights:  per-sample weight, indexed by sample position (presumably 1-D with dataset.size(1)
//           entries — TODO confirm).
// Throws std::logic_error if this node or any parent is missing from `features`.
// All lookups run before anything is modified, so on error the node's state is unchanged.
// NOTE(review): with laplaceSmoothing == 0, a parent configuration absent from the data normalizes
// as 0/0 and yields NaN entries.
void Node::computeCPT(const torch::Tensor& dataset, const std::vector<std::string>& features, const double laplaceSmoothing, const torch::Tensor& weights)
{
    // Resolve the dataset rows of this node and its parents once, instead of once per sample.
    auto pos = std::find(features.begin(), features.end(), name);
    if (pos == features.end()) {
        throw std::logic_error("Feature " + name + " not found in dataset");
    }
    const int name_index = pos - features.begin();
    std::vector<int> parent_indices;
    parent_indices.reserve(parents.size());
    for (const auto& parent : parents) {
        pos = std::find(features.begin(), features.end(), parent->getName());
        if (pos == features.end()) {
            throw std::logic_error("Feature parent " + parent->getName() + " not found in dataset");
        }
        parent_indices.push_back(pos - features.begin());
    }
    // Get dimensions of the CPT: this node's states first, then each parent's states in order
    dimensions.clear();
    dimensions.push_back(numStates);
    std::transform(parents.begin(), parents.end(), std::back_inserter(dimensions), [](const auto& parent) { return parent->getNumStates(); });
    // Every cell starts with the Laplace pseudo-count
    cpTable = torch::zeros(dimensions, torch::kFloat) + laplaceSmoothing;
    // Fill table with weighted counts, one sample (dataset column) at a time
    const int64_t n_samples = dataset.size(1);
    for (int64_t n_sample = 0; n_sample < n_samples; ++n_sample) {
        c10::List<c10::optional<at::Tensor>> coordinates;
        coordinates.push_back(dataset.index({ name_index, n_sample }));
        for (const int parent_index : parent_indices) {
            coordinates.push_back(dataset.index({ parent_index, n_sample }));
        }
        // Increment the count of the corresponding coordinate by this sample's weight
        cpTable.index_put_({ coordinates }, cpTable.index({ coordinates }) + weights.index({ n_sample }).item<double>());
    }
    // Normalize the counts over this node's states for every parent configuration
    cpTable = cpTable / cpTable.sum(0);
}
|
2023-11-08 17:45:35 +00:00
|
|
|
float Node::getFactorValue(std::map<std::string, int>& evidence)
|
2023-07-05 16:38:54 +00:00
|
|
|
{
|
2023-09-02 11:58:12 +00:00
|
|
|
c10::List<c10::optional<at::Tensor>> coordinates;
|
2023-07-05 16:38:54 +00:00
|
|
|
// following predetermined order of indices in the cpTable (see Node.h)
|
2023-09-02 11:58:12 +00:00
|
|
|
coordinates.push_back(at::tensor(evidence[name]));
|
2023-11-08 17:45:35 +00:00
|
|
|
transform(parents.begin(), parents.end(), std::back_inserter(coordinates), [&evidence](const auto& parent) { return at::tensor(evidence[parent->getName()]); });
|
2023-07-05 16:38:54 +00:00
|
|
|
return cpTable.index({ coordinates }).item<float>();
|
2023-07-02 18:39:13 +00:00
|
|
|
}
|
2023-11-08 17:45:35 +00:00
|
|
|
std::vector<std::string> Node::graph(const std::string& className)
|
2023-07-15 23:20:47 +00:00
|
|
|
{
|
2023-11-08 17:45:35 +00:00
|
|
|
auto output = std::vector<std::string>();
|
2023-07-15 23:20:47 +00:00
|
|
|
auto suffix = name == className ? ", fontcolor=red, fillcolor=lightblue, style=filled " : "";
|
|
|
|
output.push_back(name + " [shape=circle" + suffix + "] \n");
|
2023-07-29 22:04:18 +00:00
|
|
|
transform(children.begin(), children.end(), back_inserter(output), [this](const auto& child) { return name + " -> " + child->getName(); });
|
2023-07-15 23:20:47 +00:00
|
|
|
return output;
|
|
|
|
}
|
2023-06-29 20:00:41 +00:00
|
|
|
}
|