BayesNet/bayesnet/network/Network.cc

507 lines
22 KiB
C++
Raw Permalink Normal View History

2024-04-11 16:02:49 +00:00
// ***************************************************************
// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
// SPDX-FileType: SOURCE
// SPDX-License-Identifier: MIT
// ***************************************************************
2023-07-06 09:59:48 +00:00
#include <thread>
2024-04-07 22:13:59 +00:00
#include <sstream>
2024-06-09 15:19:38 +00:00
#include <numeric>
#include <algorithm>
2023-06-29 20:00:41 +00:00
#include "Network.h"
2024-03-08 21:20:54 +00:00
#include "bayesnet/utils/bayesnetUtils.h"
#include "bayesnet/utils/CountingSemaphore.h"
#include <pthread.h>
#include <fstream>
2023-06-29 20:00:41 +00:00
namespace bayesnet {
Network::Network() : fitted{ false }, classNumStates{ 0 }
2023-07-02 14:15:14 +00:00
{
2024-04-07 22:13:59 +00:00
}
2024-06-11 09:40:45 +00:00
// Copy constructor: copies the feature list, class name, class state count and
// fitted flag of `other`, replaces the copied samples tensor with a clone of it
// (when it is defined) and copy-constructs one Node per node of `other`.
// NOTE(review): every Node is duplicated through Node's own copy constructor;
// whether the parent/children pointers inside the copies refer to this
// network's nodes or still to `other`'s is decided by Node, which is not
// visible here — confirm.
Network::Network(const Network& other)
    : features(other.features)
    , className(other.className)
    , classNumStates(other.getClassNumStates())
    , fitted(other.fitted)
    , samples(other.samples)
{
    if (samples.defined()) {
        samples = samples.clone();
    }
    for (const auto& entry : other.nodes) {
        nodes[entry.first] = std::make_unique<Node>(*entry.second);
    }
}
2023-08-03 18:22:33 +00:00
void Network::initialize()
{
2024-04-07 22:13:59 +00:00
features.clear();
2023-08-03 18:22:33 +00:00
className = "";
classNumStates = 0;
fitted = false;
nodes.clear();
samples = torch::Tensor();
}
// Return a mutable reference to the samples tensor stored in the network.
torch::Tensor& Network::getSamples()
{
    return this->samples;
}
2023-11-08 17:45:35 +00:00
void Network::addNode(const std::string& name)
2023-06-29 20:00:41 +00:00
{
if (fitted) {
throw std::invalid_argument("Cannot add node to a fitted network. Initialize first.");
}
2023-08-03 18:22:33 +00:00
if (name == "") {
2023-11-08 17:45:35 +00:00
throw std::invalid_argument("Node name cannot be empty");
2023-08-03 18:22:33 +00:00
}
2023-06-30 19:24:12 +00:00
if (nodes.find(name) != nodes.end()) {
return;
2023-06-30 19:24:12 +00:00
}
2023-08-05 12:40:42 +00:00
if (find(features.begin(), features.end(), name) == features.end()) {
features.push_back(name);
}
nodes[name] = std::make_unique<Node>(name);
2023-06-29 21:53:33 +00:00
}
2023-11-08 17:45:35 +00:00
std::vector<std::string> Network::getFeatures() const
{
return features;
}
2023-08-07 23:53:41 +00:00
int Network::getClassNumStates() const
2023-07-05 16:38:54 +00:00
{
return classNumStates;
}
2023-08-07 23:53:41 +00:00
int Network::getStates() const
2023-07-09 14:25:24 +00:00
{
int result = 0;
for (auto& node : nodes) {
2023-07-09 14:25:24 +00:00
result += node.second->getNumStates();
}
return result;
}
2023-11-08 17:45:35 +00:00
std::string Network::getClassName() const
2023-07-05 16:38:54 +00:00
{
return className;
}
2023-11-08 17:45:35 +00:00
bool Network::isCyclic(const std::string& nodeId, std::unordered_set<std::string>& visited, std::unordered_set<std::string>& recStack)
2023-06-29 21:53:33 +00:00
{
if (visited.find(nodeId) == visited.end()) // if node hasn't been visited yet
{
visited.insert(nodeId);
recStack.insert(nodeId);
for (Node* child : nodes[nodeId]->getChildren()) {
if (visited.find(child->getName()) == visited.end() && isCyclic(child->getName(), visited, recStack))
return true;
if (recStack.find(child->getName()) != recStack.end())
2023-06-29 21:53:33 +00:00
return true;
}
}
recStack.erase(nodeId); // remove node from recursion stack before function ends
return false;
}
2023-11-08 17:45:35 +00:00
void Network::addEdge(const std::string& parent, const std::string& child)
2023-06-29 20:00:41 +00:00
{
if (fitted) {
throw std::invalid_argument("Cannot add edge to a fitted network. Initialize first.");
}
2023-06-29 20:00:41 +00:00
if (nodes.find(parent) == nodes.end()) {
2023-11-08 17:45:35 +00:00
throw std::invalid_argument("Parent node " + parent + " does not exist");
2023-06-29 20:00:41 +00:00
}
if (nodes.find(child) == nodes.end()) {
2023-11-08 17:45:35 +00:00
throw std::invalid_argument("Child node " + child + " does not exist");
2023-06-29 20:00:41 +00:00
}
2024-07-04 16:52:41 +00:00
// Check if the edge is already in the graph
for (auto& node : nodes[parent]->getChildren()) {
if (node->getName() == child) {
throw std::invalid_argument("Edge " + parent + " -> " + child + " already exists");
}
}
2023-06-29 21:53:33 +00:00
// Temporarily add edge to check for cycles
nodes[parent]->addChild(nodes[child].get());
nodes[child]->addParent(nodes[parent].get());
2023-11-08 17:45:35 +00:00
std::unordered_set<std::string> visited;
std::unordered_set<std::string> recStack;
2023-06-29 21:53:33 +00:00
if (isCyclic(nodes[child]->getName(), visited, recStack)) // if adding this edge forms a cycle
{
2023-06-30 19:24:12 +00:00
// remove problematic edge
nodes[parent]->removeChild(nodes[child].get());
nodes[child]->removeParent(nodes[parent].get());
2023-11-08 17:45:35 +00:00
throw std::invalid_argument("Adding this edge forms a cycle in the graph.");
2023-06-29 21:53:33 +00:00
}
2023-06-29 20:00:41 +00:00
}
2023-11-08 17:45:35 +00:00
std::map<std::string, std::unique_ptr<Node>>& Network::getNodes()
2023-06-29 20:00:41 +00:00
{
return nodes;
}
2023-11-08 17:45:35 +00:00
// Validate the arguments handed to fit() against each other and against the
// names already registered in the network. Throws std::invalid_argument on the
// first rule that fails; rules are checked in this order:
//   1. weights has one element per sample
//   2. X and y have the same number of samples
//   3. X has as many features as featureNames has entries
//   4. the network has at least one registered name
//   5. the network has exactly n_features + 1 registered names
//   6. className is one of the registered names
//   7. every entry of featureNames is a registered name and a key of states
void Network::checkFitData(int n_samples, int n_features, int n_samples_y, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights)
{
    const auto isRegistered = [this](const std::string& name) {
        return std::find(features.begin(), features.end(), name) != features.end();
    };
    if (weights.size(0) != n_samples) {
        throw std::invalid_argument("Weights (" + std::to_string(weights.size(0)) + ") must have the same number of elements as samples (" + std::to_string(n_samples) + ") in Network::fit");
    }
    if (n_samples != n_samples_y) {
        throw std::invalid_argument("X and y must have the same number of samples in Network::fit (" + std::to_string(n_samples) + " != " + std::to_string(n_samples_y) + ")");
    }
    if (n_features != featureNames.size()) {
        throw std::invalid_argument("X and features must have the same number of features in Network::fit (" + std::to_string(n_features) + " != " + std::to_string(featureNames.size()) + ")");
    }
    if (features.empty()) {
        throw std::invalid_argument("The network has not been initialized. You must call addNode() before calling fit()");
    }
    // features holds the class as well, hence the "- 1".
    if (n_features != features.size() - 1) {
        throw std::invalid_argument("X and local features must have the same number of features in Network::fit (" + std::to_string(n_features) + " != " + std::to_string(features.size() - 1) + ")");
    }
    if (!isRegistered(className)) {
        throw std::invalid_argument("Class Name not found in Network::features");
    }
    for (const auto& feature : featureNames) {
        if (!isRegistered(feature)) {
            throw std::invalid_argument("Feature " + feature + " not found in Network::features");
        }
        if (states.count(feature) == 0) {
            throw std::invalid_argument("Feature " + feature + " not found in states");
        }
    }
}
2023-11-08 17:45:35 +00:00
void Network::setStates(const std::map<std::string, std::vector<int>>& states)
2023-08-05 12:40:42 +00:00
{
// Set states to every Node in the network
2023-11-08 17:45:35 +00:00
for_each(features.begin(), features.end(), [this, &states](const std::string& feature) {
2023-09-10 17:50:36 +00:00
nodes.at(feature)->setNumStates(states.at(feature).size());
});
classNumStates = nodes.at(className)->getNumStates();
2023-08-05 12:40:42 +00:00
}
2023-08-03 18:22:33 +00:00
// X comes in nxm, where n is the number of features and m the number of samples
2024-06-11 09:40:45 +00:00
void Network::fit(const torch::Tensor& X, const torch::Tensor& y, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states, const Smoothing_t smoothing)
{
2023-08-12 22:59:02 +00:00
checkFitData(X.size(1), X.size(0), y.size(0), featureNames, className, states, weights);
2023-07-25 23:39:01 +00:00
this->className = className;
2023-11-08 17:45:35 +00:00
torch::Tensor ytmp = torch::transpose(y.view({ y.size(0), 1 }), 0, 1);
2023-08-03 18:22:33 +00:00
samples = torch::cat({ X , ytmp }, 0);
2023-07-25 23:39:01 +00:00
for (int i = 0; i < featureNames.size(); ++i) {
2023-08-03 18:22:33 +00:00
auto row_feature = X.index({ i, "..." });
2023-07-25 23:39:01 +00:00
}
2024-06-11 09:40:45 +00:00
completeFit(states, weights, smoothing);
}
2024-06-11 09:40:45 +00:00
// Fit the network from an already assembled samples tensor.
// One row of `samples` is not counted as a feature (n_features is passed to
// checkFitData() as size(0) - 1); the other fit() overloads in this file store
// the class as the last row, so the same layout is assumed here — TODO confirm.
// The tensor is stored as given (no clone) and the fitting is delegated to
// completeFit().
void Network::fit(const torch::Tensor& samples, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states, const Smoothing_t smoothing)
{
    const auto n_samples = samples.size(1);
    const auto n_features = samples.size(0) - 1;
    checkFitData(n_samples, n_features, n_samples, featureNames, className, states, weights);
    this->samples = samples;
    this->className = className;
    completeFit(states, weights, smoothing);
}
2023-08-03 18:22:33 +00:00
// input_data comes in nxm, where n is the number of features and m the number of samples
2024-06-11 09:40:45 +00:00
void Network::fit(const std::vector<std::vector<int>>& input_data, const std::vector<int>& labels, const std::vector<double>& weights_, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states, const Smoothing_t smoothing)
2023-06-30 00:46:06 +00:00
{
2023-08-12 22:59:02 +00:00
const torch::Tensor weights = torch::tensor(weights_, torch::kFloat64);
checkFitData(input_data[0].size(), input_data.size(), labels.size(), featureNames, className, states, weights);
2023-06-30 19:24:12 +00:00
this->className = className;
// Build tensor of samples (nxm) (n+1 because of the class)
2023-08-03 18:22:33 +00:00
samples = torch::zeros({ static_cast<int>(input_data.size() + 1), static_cast<int>(input_data[0].size()) }, torch::kInt32);
2023-06-30 19:24:12 +00:00
for (int i = 0; i < featureNames.size(); ++i) {
2023-08-03 18:22:33 +00:00
samples.index_put_({ i, "..." }, torch::tensor(input_data[i], torch::kInt32));
2023-06-30 19:24:12 +00:00
}
2023-08-03 18:22:33 +00:00
samples.index_put_({ -1, "..." }, torch::tensor(labels, torch::kInt32));
2024-06-11 09:40:45 +00:00
completeFit(states, weights, smoothing);
2023-07-25 23:39:01 +00:00
}
2024-06-11 09:40:45 +00:00
// Finish a fit() call once `samples` and `className` are in place: set the
// number of states of every node, then compute each node's CPT on its own
// thread and mark the network as fitted after all threads have been joined.
// A slot is acquired from the CountingSemaphore instance before each thread is
// started and released by the worker after computeCPT() returns.
// Smoothing factor passed to computeCPT():
//   ORIGINAL -> 1 / number of samples, LAPLACE -> 1, CESTNIK -> 1 / node states,
//   anything else -> 0 (no smoothing).
// NOTE(review): an exception escaping computeCPT() would leave the worker
// without releasing its semaphore slot — confirm computeCPT() cannot throw here.
void Network::completeFit(const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights, const Smoothing_t smoothing)
{
    setStates(states);
    auto& semaphore = CountingSemaphore::getInstance();
    const double n_samples = static_cast<double>(samples.size(1));
    auto worker = [&](std::pair<const std::string, std::unique_ptr<Node>>& entry, int workerId) {
        // Name the thread so it can be recognised in debuggers/profilers.
        const std::string threadName = "FitWorker-" + std::to_string(workerId);
#if defined(__linux__)
        pthread_setname_np(pthread_self(), threadName.c_str());
#else
        pthread_setname_np(threadName.c_str());
#endif
        const double numStates = static_cast<double>(entry.second->getNumStates());
        double smoothing_factor = 0.0; // No smoothing unless a known method is selected
        switch (smoothing) {
            case Smoothing_t::ORIGINAL:
                smoothing_factor = 1.0 / n_samples;
                break;
            case Smoothing_t::LAPLACE:
                smoothing_factor = 1.0;
                break;
            case Smoothing_t::CESTNIK:
                smoothing_factor = 1.0 / numStates;
                break;
            default:
                break;
        }
        entry.second->computeCPT(samples, features, smoothing_factor, weights);
        semaphore.release();
    };
    std::vector<std::thread> threads;
    int nextId = 0;
    for (auto& entry : nodes) {
        semaphore.acquire();
        threads.emplace_back(worker, std::ref(entry), nextId++);
    }
    for (auto& runningThread : threads) {
        runningThread.join();
    }
    fitted = true;
}
2023-08-03 18:22:33 +00:00
// Predict every column of `samples` (one column per sample, one row per feature).
// Throws std::logic_error if the network is not fitted and std::invalid_argument
// if the number of rows differs from the number of registered names minus one.
// Each sample is evaluated on its own thread; a CountingSemaphore slot is
// acquired before a thread starts and released by the worker at its end, and
// writes into the shared result tensor are serialised with a mutex.
// Returns the (samples x classNumStates) float64 probability tensor when
// `proba` is true, otherwise the argmax of that tensor along dimension 1.
torch::Tensor Network::predict_tensor(const torch::Tensor& samples, const bool proba)
{
    if (!fitted) {
        throw std::logic_error("You must call fit() before calling predict()");
    }
    // Ensure the sample size is equal to the number of features
    if (samples.size(0) != features.size() - 1) {
        throw std::invalid_argument("(T) Sample size (" + std::to_string(samples.size(0)) +
            ") does not match the number of features (" + std::to_string(features.size() - 1) + ")");
    }
    torch::Tensor result = torch::zeros({ samples.size(1), classNumStates }, torch::kFloat64);
    std::mutex mtx;
    auto& semaphore = CountingSemaphore::getInstance();
    auto worker = [&](const torch::Tensor& sample, int row) {
        const std::string threadName = "PredictWorker-" + std::to_string(row);
#if defined(__linux__)
        pthread_setname_np(pthread_self(), threadName.c_str());
#else
        pthread_setname_np(threadName.c_str());
#endif
        const auto probabilities = predict_sample(sample);
        const auto rowTensor = torch::tensor(probabilities, torch::kFloat64);
        {
            std::lock_guard<std::mutex> lock(mtx);
            result.index_put_({ row, "..." }, rowTensor);
        }
        semaphore.release();
    };
    std::vector<std::thread> threads;
    for (int column = 0; column < samples.size(1); ++column) {
        semaphore.acquire();
        const torch::Tensor sample = samples.index({ "...", column });
        threads.emplace_back(worker, sample, column);
    }
    for (auto& runningThread : threads) {
        runningThread.join();
    }
    return proba ? result : result.argmax(1);
}
2023-08-03 18:22:33 +00:00
// Return mxn tensor of probabilities
2023-11-08 17:45:35 +00:00
torch::Tensor Network::predict_proba(const torch::Tensor& samples)
2023-08-03 18:22:33 +00:00
{
return predict_tensor(samples, true);
}
// Return mxn tensor of probabilities
2023-11-08 17:45:35 +00:00
torch::Tensor Network::predict(const torch::Tensor& samples)
2023-07-30 17:00:02 +00:00
{
2023-08-03 18:22:33 +00:00
return predict_tensor(samples, false);
2023-07-30 17:00:02 +00:00
}
2023-11-08 17:45:35 +00:00
// Return mx1 std::vector of predictions
// tsamples is nxm std::vector of samples
std::vector<int> Network::predict(const std::vector<std::vector<int>>& tsamples)
{
if (!fitted) {
2023-11-08 17:45:35 +00:00
throw std::logic_error("You must call fit() before calling predict()");
}
2024-06-21 11:58:42 +00:00
// Ensure the sample size is equal to the number of features
if (tsamples.size() != features.size() - 1) {
throw std::invalid_argument("(V) Sample size (" + std::to_string(tsamples.size()) +
") does not match the number of features (" + std::to_string(features.size() - 1) + ")");
}
std::vector<int> predictions(tsamples[0].size(), 0);
2023-11-08 17:45:35 +00:00
std::vector<int> sample;
2024-06-21 11:58:42 +00:00
std::vector<std::thread> threads;
auto& semaphore = CountingSemaphore::getInstance();
2024-06-23 11:02:40 +00:00
auto worker = [&](const std::vector<int>& sample, const int row, int& prediction) {
std::string threadName = "(V)PWorker-" + std::to_string(row);
#if defined(__linux__)
pthread_setname_np(pthread_self(), threadName.c_str());
#else
pthread_setname_np(threadName.c_str());
#endif
2024-06-21 11:58:42 +00:00
auto classProbabilities = predict_sample(sample);
auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end());
int predictedClass = distance(classProbabilities.begin(), maxElem);
2024-06-23 11:02:40 +00:00
prediction = predictedClass;
2024-06-21 11:58:42 +00:00
semaphore.release();
};
for (int row = 0; row < tsamples[0].size(); ++row) {
sample.clear();
for (int col = 0; col < tsamples.size(); ++col) {
sample.push_back(tsamples[col][row]);
}
2024-06-23 11:02:40 +00:00
semaphore.acquire();
threads.emplace_back(worker, sample, row, std::ref(predictions[row]));
2024-06-21 11:58:42 +00:00
}
for (auto& thread : threads) {
thread.join();
}
return predictions;
}
2023-11-08 17:45:35 +00:00
// Return mxn std::vector of probabilities
// tsamples is nxm std::vector of samples
2023-11-08 17:45:35 +00:00
std::vector<std::vector<double>> Network::predict_proba(const std::vector<std::vector<int>>& tsamples)
{
if (!fitted) {
2023-11-08 17:45:35 +00:00
throw std::logic_error("You must call fit() before calling predict_proba()");
}
2024-06-21 11:58:42 +00:00
// Ensure the sample size is equal to the number of features
if (tsamples.size() != features.size() - 1) {
throw std::invalid_argument("(V) Sample size (" + std::to_string(tsamples.size()) +
") does not match the number of features (" + std::to_string(features.size() - 1) + ")");
}
2024-06-23 11:02:40 +00:00
std::vector<std::vector<double>> predictions(tsamples[0].size(), std::vector<double>(classNumStates, 0.0));
2023-11-08 17:45:35 +00:00
std::vector<int> sample;
2024-06-23 11:02:40 +00:00
std::vector<std::thread> threads;
auto& semaphore = CountingSemaphore::getInstance();
auto worker = [&](const std::vector<int>& sample, int row, std::vector<double>& predictions) {
std::string threadName = "(V)PWorker-" + std::to_string(row);
#if defined(__linux__)
pthread_setname_np(pthread_self(), threadName.c_str());
#else
pthread_setname_np(threadName.c_str());
#endif
std::vector<double> classProbabilities = predict_sample(sample);
predictions = classProbabilities;
semaphore.release();
};
for (int row = 0; row < tsamples[0].size(); ++row) {
sample.clear();
for (int col = 0; col < tsamples.size(); ++col) {
sample.push_back(tsamples[col][row]);
}
2024-06-23 11:02:40 +00:00
semaphore.acquire();
threads.emplace_back(worker, sample, row, std::ref(predictions[row]));
}
for (auto& thread : threads) {
thread.join();
}
return predictions;
}
2023-11-08 17:45:35 +00:00
double Network::score(const std::vector<std::vector<int>>& tsamples, const std::vector<int>& labels)
{
2023-11-08 17:45:35 +00:00
std::vector<int> y_pred = predict(tsamples);
int correct = 0;
for (int i = 0; i < y_pred.size(); ++i) {
if (y_pred[i] == labels[i]) {
correct++;
}
}
return (double)correct / y_pred.size();
}
2023-11-08 17:45:35 +00:00
// Return 1xn std::vector of probabilities
std::vector<double> Network::predict_sample(const std::vector<int>& sample)
2023-07-02 14:31:50 +00:00
{
2023-11-08 17:45:35 +00:00
std::map<std::string, int> evidence;
for (int i = 0; i < sample.size(); ++i) {
evidence[features[i]] = sample[i];
2023-07-02 14:31:50 +00:00
}
2023-07-06 22:33:04 +00:00
return exactInference(evidence);
2023-07-02 14:31:50 +00:00
}
2023-11-08 17:45:35 +00:00
// Return 1xn std::vector of probabilities (one per class state) for one sample
// given as a tensor. Element i along dimension 0 is read as an int and used as
// evidence for the i-th registered name; the result comes from exactInference().
// NOTE(review): `sample` is assumed to have no more elements than `features`;
// this function does not check it.
std::vector<double> Network::predict_sample(const torch::Tensor& sample)
{
    std::map<std::string, int> evidence;
    const auto numValues = sample.size(0);
    for (int position = 0; position < numValues; ++position) {
        evidence[features[position]] = sample[position].item<int>();
    }
    return exactInference(evidence);
}
2023-11-08 17:45:35 +00:00
std::vector<double> Network::exactInference(std::map<std::string, int>& evidence)
2023-07-06 09:01:58 +00:00
{
2023-11-08 17:45:35 +00:00
std::vector<double> result(classNumStates, 0.0);
2024-06-21 11:58:42 +00:00
auto completeEvidence = std::map<std::string, int>(evidence);
for (int i = 0; i < classNumStates; ++i) {
2024-06-18 21:18:24 +00:00
completeEvidence[getClassName()] = i;
2024-06-21 11:58:42 +00:00
double partial = 1.0;
for (auto& node : getNodes()) {
partial *= node.second->getFactorValue(completeEvidence);
2024-06-18 21:18:24 +00:00
}
2024-06-21 11:58:42 +00:00
result[i] = partial;
2023-07-06 09:01:58 +00:00
}
// Normalize result
2024-06-09 15:19:38 +00:00
double sum = std::accumulate(result.begin(), result.end(), 0.0);
2023-08-16 10:32:51 +00:00
transform(result.begin(), result.end(), result.begin(), [sum](const double& value) { return value / sum; });
2023-07-06 09:01:58 +00:00
return result;
}
2023-11-08 17:45:35 +00:00
std::vector<std::string> Network::show() const
2023-07-13 14:59:06 +00:00
{
2023-11-08 17:45:35 +00:00
std::vector<std::string> result;
2023-07-13 14:59:06 +00:00
// Draw the network
for (auto& node : nodes) {
2023-11-08 17:45:35 +00:00
std::string line = node.first + " -> ";
2023-07-13 14:59:06 +00:00
for (auto child : node.second->getChildren()) {
line += child->getName() + ", ";
}
result.push_back(line);
}
return result;
}
2023-11-08 17:45:35 +00:00
std::vector<std::string> Network::graph(const std::string& title) const
2023-07-15 23:20:47 +00:00
{
2023-11-08 17:45:35 +00:00
auto output = std::vector<std::string>();
2023-07-15 23:20:47 +00:00
auto prefix = "digraph BayesNet {\nlabel=<BayesNet ";
auto suffix = ">\nfontsize=30\nfontcolor=blue\nlabelloc=t\nlayout=circo\n";
2023-11-08 17:45:35 +00:00
std::string header = prefix + title + suffix;
2023-07-15 23:20:47 +00:00
output.push_back(header);
for (auto& node : nodes) {
auto result = node.second->graph(className);
output.insert(output.end(), result.begin(), result.end());
}
output.push_back("}\n");
return output;
}
2023-11-08 17:45:35 +00:00
std::vector<std::pair<std::string, std::string>> Network::getEdges() const
{
2023-11-08 17:45:35 +00:00
auto edges = std::vector<std::pair<std::string, std::string>>();
for (const auto& node : nodes) {
auto head = node.first;
for (const auto& child : node.second->getChildren()) {
auto tail = child->getName();
edges.push_back({ head, tail });
}
}
return edges;
}
2023-08-07 23:53:41 +00:00
int Network::getNumEdges() const
{
return getEdges().size();
}
2023-11-08 17:45:35 +00:00
std::vector<std::string> Network::topological_sort()
2023-08-01 22:56:52 +00:00
{
/* Check if al the fathers of every node are before the node */
auto result = features;
2023-08-03 18:22:33 +00:00
result.erase(remove(result.begin(), result.end(), className), result.end());
2023-08-01 22:56:52 +00:00
bool ending{ false };
while (!ending) {
ending = true;
for (auto feature : features) {
auto fathers = nodes[feature]->getParents();
for (const auto& father : fathers) {
auto fatherName = father->getName();
if (fatherName == className) {
continue;
}
2023-08-03 18:22:33 +00:00
// Check if father is placed before the actual feature
2023-08-01 22:56:52 +00:00
auto it = find(result.begin(), result.end(), fatherName);
if (it != result.end()) {
auto it2 = find(result.begin(), result.end(), feature);
if (it2 != result.end()) {
if (distance(it, it2) < 0) {
2023-08-03 18:22:33 +00:00
// if it is not, insert it before the feature
2023-08-01 22:56:52 +00:00
result.erase(remove(result.begin(), result.end(), fatherName), result.end());
result.insert(it2, fatherName);
ending = false;
}
}
}
}
}
}
return result;
}
2024-04-07 22:13:59 +00:00
std::string Network::dump_cpt() const
2023-08-03 18:22:33 +00:00
{
2024-04-07 22:13:59 +00:00
std::stringstream oss;
2023-08-03 18:22:33 +00:00
for (auto& node : nodes) {
2024-04-07 22:13:59 +00:00
oss << "* " << node.first << ": (" << node.second->getNumStates() << ") : " << node.second->getCPT().sizes() << std::endl;
oss << node.second->getCPT() << std::endl;
2023-08-03 18:22:33 +00:00
}
2024-04-07 22:13:59 +00:00
return oss.str();
2023-08-03 18:22:33 +00:00
}
2023-06-29 20:00:41 +00:00
}