1// ***************************************************************
2// SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3// SPDX-FileType: SOURCE
4// SPDX-License-Identifier: MIT
5// ***************************************************************
11#include "bayesnet/utils/bayesnetUtils.h"
13 Network::Network() : fitted{ false }, maxThreads{ 0.95 }, classNumStates{ 0 }, laplaceSmoothing{ 0 }
// Constructor taking the maxThreads value: the network starts unfitted, with
// maxThreads = maxT, no class states and a Laplace smoothing factor of 0.
// NOTE(review): the constructor body is not visible in this excerpt.
16 Network::Network(
float maxT) : fitted{ false }, maxThreads{ maxT }, classNumStates{ 0 }, laplaceSmoothing{ 0 }
20 Network::Network(
const Network& other) : laplaceSmoothing(other.laplaceSmoothing), features(other.features), className(other.className), classNumStates(other.getClassNumStates()),
21 maxThreads(other.getMaxThreads()), fitted(other.fitted), samples(other.samples)
23 if (samples.defined())
24 samples = samples.clone();
25 for (
const auto& node : other.nodes) {
26 nodes[node.first] = std::make_unique<Node>(*node.second);
// Among its steps, replaces `samples` with a default-constructed torch::Tensor.
// NOTE(review): the remaining statements of this body are not visible in this
// excerpt; presumably they reset the rest of the network state -- confirm.
29 void Network::initialize()
36 samples = torch::Tensor();
// Const accessor returning a float.
// NOTE(review): the body is not visible in this excerpt; presumably it returns
// the maxThreads member -- confirm.
38 float Network::getMaxThreads()
const
// Accessor returning a mutable reference to a torch::Tensor.
// NOTE(review): the body is not visible in this excerpt; presumably it returns
// the `samples` member -- confirm.
42 torch::Tensor& Network::getSamples()
// Adds a node called `name` to the network.
// Throws std::invalid_argument("Node name cannot be empty"); the condition
// guarding the throw is not visible in this excerpt.
46 void Network::addNode(
const std::string& name)
49 throw std::invalid_argument(
"Node name cannot be empty");
// A node with this name already exists.
// NOTE(review): the body of this branch is not visible here.
51 if (nodes.find(name) != nodes.end()) {
// Record the name in `features` only if it is not listed yet.
54 if (find(features.begin(), features.end(), name) == features.end()) {
55 features.push_back(name);
// Create the Node and store it under its name.
57 nodes[name] = std::make_unique<Node>(name);
// Const accessor returning a std::vector<std::string> by value.
// NOTE(review): the body is not visible in this excerpt; presumably it returns
// the `features` member -- confirm.
59 std::vector<std::string> Network::getFeatures()
const
63 int Network::getClassNumStates()
const
65 return classNumStates;
// Adds up getNumStates() over every node of the network into `result`.
// NOTE(review): the declaration of `result` and the return statement are not
// visible in this excerpt.
67 int Network::getStates()
const
70 for (
auto& node : nodes) {
71 result += node.second->getNumStates();
// Const accessor returning a std::string by value.
// NOTE(review): the body is not visible in this excerpt; presumably it returns
// the `className` member -- confirm.
75 std::string Network::getClassName()
const
// Recursive helper used to look for a cycle reachable from `nodeId`.
//   nodeId   - name of the node to explore
//   visited  - names of the nodes already explored (updated in place)
//   recStack - names of the nodes on the current recursion path (updated in place)
// NOTE(review): the return statements are not visible in this excerpt, so the
// exact value returned on each path must be confirmed against the full file.
79 bool Network::isCyclic(
const std::string& nodeId, std::unordered_set<std::string>& visited, std::unordered_set<std::string>& recStack)
81 if (visited.find(nodeId) == visited.end())
// if node hasn't been visited yet
// Mark the node as explored and as part of the current path.
83 visited.insert(nodeId);
84 recStack.insert(nodeId);
85 for (Node* child : nodes[nodeId]->getChildren()) {
// Recurse only into children that have not been explored yet.
86 if (visited.find(child->getName()) == visited.end() && isCyclic(child->getName(), visited, recStack))
// The child is already on the current recursion path.
88 if (recStack.find(child->getName()) != recStack.end())
92 recStack.erase(nodeId);
// remove node from recursion stack before function ends
95 void Network::addEdge(
const std::string& parent,
const std::string& child)
97 if (nodes.find(parent) == nodes.end()) {
98 throw std::invalid_argument(
"Parent node " + parent +
" does not exist");
100 if (nodes.find(child) == nodes.end()) {
101 throw std::invalid_argument(
"Child node " + child +
" does not exist");
103 // Temporarily add edge to check for cycles
104 nodes[parent]->addChild(nodes[child].get());
105 nodes[child]->addParent(nodes[parent].get());
106 std::unordered_set<std::string> visited;
107 std::unordered_set<std::string> recStack;
108 if (isCyclic(nodes[child]->getName(), visited, recStack))
// if adding this edge forms a cycle
110 // remove problematic edge
111 nodes[parent]->removeChild(nodes[child].get());
112 nodes[child]->removeParent(nodes[parent].get());
113 throw std::invalid_argument(
"Adding this edge forms a cycle in the graph.");
// Accessor returning a mutable reference to a map from node name to the
// owning std::unique_ptr<Node>.
// NOTE(review): the body is not visible in this excerpt; presumably it returns
// the `nodes` member -- confirm.
116 std::map<std::string, std::unique_ptr<Node>>& Network::getNodes()
120 void Network::checkFitData(
int n_samples,
int n_features,
int n_samples_y,
const std::vector<std::string>& featureNames,
const std::string& className,
const std::map<std::string, std::vector<int>>& states,
const torch::Tensor& weights)
122 if (weights.size(0) != n_samples) {
123 throw std::invalid_argument(
"Weights (" + std::to_string(weights.size(0)) +
") must have the same number of elements as samples (" + std::to_string(n_samples) +
") in Network::fit");
125 if (n_samples != n_samples_y) {
126 throw std::invalid_argument(
"X and y must have the same number of samples in Network::fit (" + std::to_string(n_samples) +
" != " + std::to_string(n_samples_y) +
")");
128 if (n_features != featureNames.size()) {
129 throw std::invalid_argument(
"X and features must have the same number of features in Network::fit (" + std::to_string(n_features) +
" != " + std::to_string(featureNames.size()) +
")");
131 if (features.size() == 0) {
132 throw std::invalid_argument(
"The network has not been initialized. You must call addNode() before calling fit()");
134 if (n_features != features.size() - 1) {
135 throw std::invalid_argument(
"X and local features must have the same number of features in Network::fit (" + std::to_string(n_features) +
" != " + std::to_string(features.size() - 1) +
")");
137 if (find(features.begin(), features.end(), className) == features.end()) {
138 throw std::invalid_argument(
"Class Name not found in Network::features");
140 for (
auto& feature : featureNames) {
141 if (find(features.begin(), features.end(), feature) == features.end()) {
142 throw std::invalid_argument(
"Feature " + feature +
" not found in Network::features");
144 if (states.find(feature) == states.end()) {
145 throw std::invalid_argument(
"Feature " + feature +
" not found in states");
149 void Network::setStates(
const std::map<std::string, std::vector<int>>& states)
151 // Set states to every Node in the network
152 for_each(features.begin(), features.end(), [
this, &states](
const std::string& feature) {
153 nodes.at(feature)->setNumStates(states.at(feature).size());
155 classNumStates = nodes.at(className)->getNumStates();
157 // X comes in nxm, where n is the number of features and m the number of samples
158 void Network::fit(
const torch::Tensor& X,
const torch::Tensor& y,
const torch::Tensor& weights,
const std::vector<std::string>& featureNames,
const std::string& className,
const std::map<std::string, std::vector<int>>& states)
160 checkFitData(X.size(1), X.size(0), y.size(0), featureNames, className, states, weights);
161 this->className = className;
162 torch::Tensor ytmp = torch::transpose(y.view({ y.size(0), 1 }), 0, 1);
163 samples = torch::cat({ X , ytmp }, 0);
164 for (
int i = 0; i < featureNames.size(); ++i) {
165 auto row_feature = X.index({ i,
"..." });
167 completeFit(states, weights);
169 void Network::fit(
const torch::Tensor& samples,
const torch::Tensor& weights,
const std::vector<std::string>& featureNames,
const std::string& className,
const std::map<std::string, std::vector<int>>& states)
171 checkFitData(samples.size(1), samples.size(0) - 1, samples.size(1), featureNames, className, states, weights);
172 this->className = className;
173 this->samples = samples;
174 completeFit(states, weights);
176 // input_data comes in nxm, where n is the number of features and m the number of samples
177 void Network::fit(
const std::vector<std::vector<int>>& input_data,
const std::vector<int>& labels,
const std::vector<double>& weights_,
const std::vector<std::string>& featureNames,
const std::string& className,
const std::map<std::string, std::vector<int>>& states)
179 const torch::Tensor weights = torch::tensor(weights_, torch::kFloat64);
180 checkFitData(input_data[0].size(), input_data.size(), labels.size(), featureNames, className, states, weights);
181 this->className = className;
182 // Build tensor of samples (nxm) (n+1 because of the class)
183 samples = torch::zeros({
static_cast<int>(input_data.size() + 1),
static_cast<int>(input_data[0].size()) }, torch::kInt32);
184 for (
int i = 0; i < featureNames.size(); ++i) {
185 samples.index_put_({ i,
"..." }, torch::tensor(input_data[i], torch::kInt32));
187 samples.index_put_({ -1,
"..." }, torch::tensor(labels, torch::kInt32));
188 completeFit(states, weights);
// Final stage shared by the fit() overloads: sets the smoothing factor and has
// every node compute its CPT from `samples`, one std::thread per node.
// NOTE(review): several statements of this body are not visible in this
// excerpt (including any use of `states` and the body of the final loop over
// the threads, presumably a join) -- confirm against the full file.
190 void Network::completeFit(
const std::map<std::string, std::vector<int>>& states,
const torch::Tensor& weights)
// Smoothing factor is the inverse of the number of samples (columns of `samples`).
193 laplaceSmoothing = 1.0 / samples.size(1);
// To use in CPT computation
194 std::vector<std::thread> threads;
195 for (
auto& node : nodes) {
// `node` and `weights` are captured by reference, so both must stay alive
// while the thread runs.
196 threads.emplace_back([
this, &node, &weights]() {
197 node.second->computeCPT(samples, features, laplaceSmoothing, weights);
200 for (
auto& thread : threads) {
// Runs predict_sample() on every column of `samples` and collects the class
// probabilities in a (number of samples) x classNumStates float64 tensor.
//   samples - tensor whose columns are the samples to classify
//   proba   - selects between probabilities and predictions
// Throws std::logic_error("You must call fit() before calling predict()").
// NOTE(review): the condition guarding the throw and the branch on `proba`
// are not visible in this excerpt; only the argmax return is.
205 torch::Tensor Network::predict_tensor(
const torch::Tensor& samples,
const bool proba)
208 throw std::logic_error(
"You must call fit() before calling predict()");
210 torch::Tensor result;
211 result = torch::zeros({ samples.size(1), classNumStates }, torch::kFloat64);
212 for (
int i = 0; i < samples.size(1); ++i) {
// Column i holds the values of sample i.
213 const torch::Tensor sample = samples.index({
"...", i });
214 auto psample = predict_sample(sample);
215 auto temp = torch::tensor(psample, torch::kFloat64);
216 // result.index_put_({ i, "..." }, torch::tensor(predict_sample(sample), torch::kFloat64));
// Store the probabilities of sample i in row i.
217 result.index_put_({ i,
"..." }, temp);
// Index of the largest probability in every row.
221 return result.argmax(1);
223 // Return mxn tensor of probabilities
224 torch::Tensor Network::predict_proba(
const torch::Tensor& samples)
226 return predict_tensor(samples,
true);
229 // Return mxn tensor of probabilities
230 torch::Tensor Network::predict(
const torch::Tensor& samples)
232 return predict_tensor(samples,
false);
// Predicts the class of every sample held in a features x samples vector.
// Throws std::logic_error("You must call fit() before calling predict()").
// NOTE(review): the condition guarding the throw, any reset of `sample`
// between rows and the return statement are not visible in this excerpt.
// NOTE(review): tsamples[0] is read with no emptiness check in the visible lines.
235 // Return mx1 std::vector of predictions
236 // tsamples is nxm std::vector of samples
237 std::vector<int> Network::predict(
const std::vector<std::vector<int>>& tsamples)
240 throw std::logic_error(
"You must call fit() before calling predict()");
242 std::vector<int> predictions;
243 std::vector<int> sample;
244 for (
int row = 0; row < tsamples[0].size(); ++row) {
// Gather the value of every feature for sample `row`.
246 for (
int col = 0; col < tsamples.size(); ++col) {
247 sample.push_back(tsamples[col][row]);
249 std::vector<double> classProbabilities = predict_sample(sample);
250 // Find the class with the maximum posterior probability
251 auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end());
// The position of the maximum is the predicted class index.
252 int predictedClass = distance(classProbabilities.begin(), maxElem);
253 predictions.push_back(predictedClass);
// Computes the class probabilities of every sample held in a
// features x samples vector, one predict_sample() call per sample.
// Throws std::logic_error("You must call fit() before calling predict_proba()").
// NOTE(review): the condition guarding the throw, any reset of `sample`
// between rows and the return statement are not visible in this excerpt.
// NOTE(review): tsamples[0] is read with no emptiness check in the visible lines.
257 // Return mxn std::vector of probabilities
258 // tsamples is nxm std::vector of samples
259 std::vector<std::vector<double>> Network::predict_proba(
const std::vector<std::vector<int>>& tsamples)
262 throw std::logic_error(
"You must call fit() before calling predict_proba()");
264 std::vector<std::vector<double>> predictions;
265 std::vector<int> sample;
266 for (
int row = 0; row < tsamples[0].size(); ++row) {
// Gather the value of every feature for sample `row`.
268 for (
int col = 0; col < tsamples.size(); ++col) {
269 sample.push_back(tsamples[col][row]);
271 predictions.push_back(predict_sample(sample));
// Accuracy of predict(tsamples) against `labels`: `correct` divided by the
// number of predictions.
// NOTE(review): the declaration and the update of `correct` are not visible
// in this excerpt.
// NOTE(review): no visible check that `labels` has as many entries as y_pred,
// nor that y_pred is non-empty before the division.
275 double Network::score(
const std::vector<std::vector<int>>& tsamples,
const std::vector<int>& labels)
277 std::vector<int> y_pred = predict(tsamples);
279 for (
int i = 0; i < y_pred.size(); ++i) {
280 if (y_pred[i] == labels[i]) {
// The cast forces a floating-point division.
284 return (
double)correct / y_pred.size();
286 // Return 1xn std::vector of probabilities
287 std::vector<double> Network::predict_sample(
const std::vector<int>& sample)
289 // Ensure the sample size is equal to the number of features
290 if (sample.size() != features.size() - 1) {
291 throw std::invalid_argument(
"Sample size (" + std::to_string(sample.size()) +
292 ") does not match the number of features (" + std::to_string(features.size() - 1) +
")");
294 std::map<std::string, int> evidence;
295 for (
int i = 0; i < sample.size(); ++i) {
296 evidence[features[i]] = sample[i];
298 return exactInference(evidence);
300 // Return 1xn std::vector of probabilities
301 std::vector<double> Network::predict_sample(
const torch::Tensor& sample)
303 // Ensure the sample size is equal to the number of features
304 if (sample.size(0) != features.size() - 1) {
305 throw std::invalid_argument(
"Sample size (" + std::to_string(sample.size(0)) +
306 ") does not match the number of features (" + std::to_string(features.size() - 1) +
")");
308 std::map<std::string, int> evidence;
309 for (
int i = 0; i < sample.size(0); ++i) {
310 evidence[features[i]] = sample[i].item<
int>();
312 return exactInference(evidence);
// Multiplies into `result` the factor value that every node reports for the
// given evidence.
//   completeEvidence - value assigned to each variable, keyed by name
// NOTE(review): the declaration of `result` and the return statement are not
// visible in this excerpt.
314 double Network::computeFactor(std::map<std::string, int>& completeEvidence)
317 for (
auto& node : getNodes()) {
318 result *= node.second->getFactorValue(completeEvidence);
// Computes one factor per class state, each in its own std::thread, and
// divides every entry of `result` by the sum of all entries.
//   evidence - observed value of each feature, keyed by name
// NOTE(review): the declaration of `mtx`, the statement executed under the
// lock, the body of the loop over the threads and the return statement are
// not visible in this excerpt.
// NOTE(review): no visible guard against `sum` being 0 before the division.
322 std::vector<double> Network::exactInference(std::map<std::string, int>& evidence)
324 std::vector<double> result(classNumStates, 0.0);
325 std::vector<std::thread> threads;
327 for (
int i = 0; i < classNumStates; ++i) {
// `result`, `evidence` and `mtx` are captured by reference; `i` by value.
328 threads.emplace_back([
this, &result, &evidence, i, &mtx]() {
// Each thread works on its own copy of the evidence, extended with class state i.
329 auto completeEvidence = std::map<std::string, int>(evidence);
330 completeEvidence[getClassName()] = i;
331 double factor = computeFactor(completeEvidence);
332 std::lock_guard<std::mutex> lock(mtx);
336 for (
auto& thread : threads) {
// Divide every entry by the total.
340 double sum = accumulate(result.begin(), result.end(), 0.0);
341 transform(result.begin(), result.end(), result.begin(), [sum](
const double& value) { return value / sum; });
// Builds one text line per node in the form "name -> child, child, " and
// collects the lines in `result`.
// NOTE(review): the return statement is not visible in this excerpt.
344 std::vector<std::string> Network::show()
const
346 std::vector<std::string> result;
348 for (
auto& node : nodes) {
349 std::string line = node.first +
" -> ";
// Every child name is followed by ", ", including the last one.
350 for (
auto child : node.second->getChildren()) {
351 line += child->getName() +
", ";
353 result.push_back(line);
// Builds the lines of a "digraph BayesNet" description of the network: a
// header carrying `title`, the lines returned by every node's graph(className),
// and a closing "}\n".
//   title - text appended to "BayesNet " in the graph label
// NOTE(review): the return statement is not visible in this excerpt.
357 std::vector<std::string> Network::graph(
const std::string& title)
const
359 auto output = std::vector<std::string>();
360 auto prefix =
"digraph BayesNet {\nlabel=<BayesNet ";
361 auto suffix =
">\nfontsize=30\nfontcolor=blue\nlabelloc=t\nlayout=circo\n";
362 std::string header = prefix + title + suffix;
363 output.push_back(header);
364 for (
auto& node : nodes) {
// Append the lines produced by this node.
365 auto result = node.second->graph(className);
366 output.insert(output.end(), result.begin(), result.end());
368 output.push_back(
"}\n");
// Collects one (node name, child name) pair for every child of every node.
// NOTE(review): the return statement is not visible in this excerpt.
371 std::vector<std::pair<std::string, std::string>> Network::getEdges()
const
373 auto edges = std::vector<std::pair<std::string, std::string>>();
374 for (
const auto& node : nodes) {
375 auto head = node.first;
376 for (
const auto& child : node.second->getChildren()) {
377 auto tail = child->getName();
378 edges.push_back({ head, tail });
383 int Network::getNumEdges()
const
385 return getEdges().size();
// Orders the feature names (class name excluded) by moving a parent in front
// of a feature whenever it is found after that feature.
// NOTE(review): the loop that uses `ending`, the body of the class-name check
// and the return statement are not visible in this excerpt.
387 std::vector<std::string> Network::topological_sort()
389 /* Check if all the fathers of every node are before the node */
// Start from the registered order, without the class name.
390 auto result = features;
391 result.erase(remove(result.begin(), result.end(), className), result.end());
392 bool ending{
false };
395 for (
auto feature : features) {
396 auto fathers = nodes[feature]->getParents();
397 for (
const auto& father : fathers) {
398 auto fatherName = father->getName();
399 if (fatherName == className) {
402 // Check if father is placed before the actual feature
403 auto it = find(result.begin(), result.end(), fatherName);
404 if (it != result.end()) {
405 auto it2 = find(result.begin(), result.end(), feature);
406 if (it2 != result.end()) {
// A negative distance means the father currently sits after the feature.
407 if (distance(it, it2) < 0) {
408 // if it is not, insert it before the feature
// it2 precedes the removed element, so it still points at the feature after the erase.
409 result.erase(remove(result.begin(), result.end(), fatherName), result.end());
410 result.insert(it2, fatherName);
420 std::string Network::dump_cpt()
const
422 std::stringstream oss;
423 for (
auto& node : nodes) {
424 oss <<
"* " << node.first <<
": (" << node.second->getNumStates() <<
") : " << node.second->getCPT().sizes() << std::endl;
425 oss << node.second->getCPT() << std::endl;