Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
#include <algorithm>
#include <atomic>
#include <exception>
#include <mutex>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <thread>
#include "Network.h"
#include "bayesnet/utils/bayesnetUtils.h"
12 : namespace bayesnet {
13 4992 : Network::Network() : fitted{ false }, maxThreads{ 0.95 }, classNumStates{ 0 }, laplaceSmoothing{ 0 }
14 : {
15 4992 : }
16 22 : Network::Network(float maxT) : fitted{ false }, maxThreads{ maxT }, classNumStates{ 0 }, laplaceSmoothing{ 0 }
17 : {
18 :
19 22 : }
20 4761 : Network::Network(const Network& other) : laplaceSmoothing(other.laplaceSmoothing), features(other.features), className(other.className), classNumStates(other.getClassNumStates()),
21 9522 : maxThreads(other.getMaxThreads()), fitted(other.fitted), samples(other.samples)
22 : {
23 4761 : if (samples.defined())
24 11 : samples = samples.clone();
25 4816 : for (const auto& node : other.nodes) {
26 55 : nodes[node.first] = std::make_unique<Node>(*node.second);
27 : }
28 4761 : }
29 3358 : void Network::initialize()
30 : {
31 3358 : features.clear();
32 3358 : className = "";
33 3358 : classNumStates = 0;
34 3358 : fitted = false;
35 3358 : nodes.clear();
36 3358 : samples = torch::Tensor();
37 3358 : }
38 4794 : float Network::getMaxThreads() const
39 : {
40 4794 : return maxThreads;
41 : }
42 132 : torch::Tensor& Network::getSamples()
43 : {
44 132 : return samples;
45 : }
46 116878 : void Network::addNode(const std::string& name)
47 : {
48 116878 : if (name == "") {
49 22 : throw std::invalid_argument("Node name cannot be empty");
50 : }
51 116856 : if (nodes.find(name) != nodes.end()) {
52 0 : return;
53 : }
54 116856 : if (find(features.begin(), features.end(), name) == features.end()) {
55 116856 : features.push_back(name);
56 : }
57 116856 : nodes[name] = std::make_unique<Node>(name);
58 : }
59 607 : std::vector<std::string> Network::getFeatures() const
60 : {
61 607 : return features;
62 : }
63 5704 : int Network::getClassNumStates() const
64 : {
65 5704 : return classNumStates;
66 : }
67 132 : int Network::getStates() const
68 : {
69 132 : int result = 0;
70 792 : for (auto& node : nodes) {
71 660 : result += node.second->getNumStates();
72 : }
73 132 : return result;
74 : }
75 5150624 : std::string Network::getClassName() const
76 : {
77 5150624 : return className;
78 : }
79 295830 : bool Network::isCyclic(const std::string& nodeId, std::unordered_set<std::string>& visited, std::unordered_set<std::string>& recStack)
80 : {
81 295830 : if (visited.find(nodeId) == visited.end()) // if node hasn't been visited yet
82 : {
83 295830 : visited.insert(nodeId);
84 295830 : recStack.insert(nodeId);
85 367384 : for (Node* child : nodes[nodeId]->getChildren()) {
86 71620 : if (visited.find(child->getName()) == visited.end() && isCyclic(child->getName(), visited, recStack))
87 66 : return true;
88 71576 : if (recStack.find(child->getName()) != recStack.end())
89 22 : return true;
90 : }
91 : }
92 295764 : recStack.erase(nodeId); // remove node from recursion stack before function ends
93 295764 : return false;
94 : }
95 224276 : void Network::addEdge(const std::string& parent, const std::string& child)
96 : {
97 224276 : if (nodes.find(parent) == nodes.end()) {
98 22 : throw std::invalid_argument("Parent node " + parent + " does not exist");
99 : }
100 224254 : if (nodes.find(child) == nodes.end()) {
101 22 : throw std::invalid_argument("Child node " + child + " does not exist");
102 : }
103 : // Temporarily add edge to check for cycles
104 224232 : nodes[parent]->addChild(nodes[child].get());
105 224232 : nodes[child]->addParent(nodes[parent].get());
106 224232 : std::unordered_set<std::string> visited;
107 224232 : std::unordered_set<std::string> recStack;
108 224232 : if (isCyclic(nodes[child]->getName(), visited, recStack)) // if adding this edge forms a cycle
109 : {
110 : // remove problematic edge
111 22 : nodes[parent]->removeChild(nodes[child].get());
112 22 : nodes[child]->removeParent(nodes[parent].get());
113 22 : throw std::invalid_argument("Adding this edge forms a cycle in the graph.");
114 : }
115 224254 : }
116 5151361 : std::map<std::string, std::unique_ptr<Node>>& Network::getNodes()
117 : {
118 5151361 : return nodes;
119 : }
120 3787 : void Network::checkFitData(int n_samples, int n_features, int n_samples_y, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights)
121 : {
122 3787 : if (weights.size(0) != n_samples) {
123 22 : throw std::invalid_argument("Weights (" + std::to_string(weights.size(0)) + ") must have the same number of elements as samples (" + std::to_string(n_samples) + ") in Network::fit");
124 : }
125 3765 : if (n_samples != n_samples_y) {
126 22 : throw std::invalid_argument("X and y must have the same number of samples in Network::fit (" + std::to_string(n_samples) + " != " + std::to_string(n_samples_y) + ")");
127 : }
128 3743 : if (n_features != featureNames.size()) {
129 22 : throw std::invalid_argument("X and features must have the same number of features in Network::fit (" + std::to_string(n_features) + " != " + std::to_string(featureNames.size()) + ")");
130 : }
131 3721 : if (features.size() == 0) {
132 22 : throw std::invalid_argument("The network has not been initialized. You must call addNode() before calling fit()");
133 : }
134 3699 : if (n_features != features.size() - 1) {
135 22 : throw std::invalid_argument("X and local features must have the same number of features in Network::fit (" + std::to_string(n_features) + " != " + std::to_string(features.size() - 1) + ")");
136 : }
137 3677 : if (find(features.begin(), features.end(), className) == features.end()) {
138 22 : throw std::invalid_argument("Class Name not found in Network::features");
139 : }
140 121476 : for (auto& feature : featureNames) {
141 117843 : if (find(features.begin(), features.end(), feature) == features.end()) {
142 22 : throw std::invalid_argument("Feature " + feature + " not found in Network::features");
143 : }
144 117821 : if (states.find(feature) == states.end()) {
145 0 : throw std::invalid_argument("Feature " + feature + " not found in states");
146 : }
147 : }
148 3633 : }
149 3633 : void Network::setStates(const std::map<std::string, std::vector<int>>& states)
150 : {
151 : // Set states to every Node in the network
152 3633 : for_each(features.begin(), features.end(), [this, &states](const std::string& feature) {
153 121388 : nodes.at(feature)->setNumStates(states.at(feature).size());
154 121388 : });
155 3633 : classNumStates = nodes.at(className)->getNumStates();
156 3633 : }
157 : // X comes in nxm, where n is the number of features and m the number of samples
158 11 : void Network::fit(const torch::Tensor& X, const torch::Tensor& y, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states)
159 : {
160 11 : checkFitData(X.size(1), X.size(0), y.size(0), featureNames, className, states, weights);
161 11 : this->className = className;
162 11 : torch::Tensor ytmp = torch::transpose(y.view({ y.size(0), 1 }), 0, 1);
163 33 : samples = torch::cat({ X , ytmp }, 0);
164 55 : for (int i = 0; i < featureNames.size(); ++i) {
165 132 : auto row_feature = X.index({ i, "..." });
166 44 : }
167 11 : completeFit(states, weights);
168 66 : }
169 3545 : void Network::fit(const torch::Tensor& samples, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states)
170 : {
171 3545 : checkFitData(samples.size(1), samples.size(0) - 1, samples.size(1), featureNames, className, states, weights);
172 3545 : this->className = className;
173 3545 : this->samples = samples;
174 3545 : completeFit(states, weights);
175 3545 : }
176 : // input_data comes in nxm, where n is the number of features and m the number of samples
177 231 : void Network::fit(const std::vector<std::vector<int>>& input_data, const std::vector<int>& labels, const std::vector<double>& weights_, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states)
178 : {
179 231 : const torch::Tensor weights = torch::tensor(weights_, torch::kFloat64);
180 231 : checkFitData(input_data[0].size(), input_data.size(), labels.size(), featureNames, className, states, weights);
181 77 : this->className = className;
182 : // Build tensor of samples (nxm) (n+1 because of the class)
183 77 : samples = torch::zeros({ static_cast<int>(input_data.size() + 1), static_cast<int>(input_data[0].size()) }, torch::kInt32);
184 385 : for (int i = 0; i < featureNames.size(); ++i) {
185 1232 : samples.index_put_({ i, "..." }, torch::tensor(input_data[i], torch::kInt32));
186 : }
187 308 : samples.index_put_({ -1, "..." }, torch::tensor(labels, torch::kInt32));
188 77 : completeFit(states, weights);
189 616 : }
190 3633 : void Network::completeFit(const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights)
191 : {
192 3633 : setStates(states);
193 3633 : laplaceSmoothing = 1.0 / samples.size(1); // To use in CPT computation
194 3633 : std::vector<std::thread> threads;
195 125021 : for (auto& node : nodes) {
196 121388 : threads.emplace_back([this, &node, &weights]() {
197 121388 : node.second->computeCPT(samples, features, laplaceSmoothing, weights);
198 121388 : });
199 : }
200 125021 : for (auto& thread : threads) {
201 121388 : thread.join();
202 : }
203 3633 : fitted = true;
204 3633 : }
205 6802 : torch::Tensor Network::predict_tensor(const torch::Tensor& samples, const bool proba)
206 : {
207 6802 : if (!fitted) {
208 22 : throw std::logic_error("You must call fit() before calling predict()");
209 : }
210 6780 : torch::Tensor result;
211 6780 : result = torch::zeros({ samples.size(1), classNumStates }, torch::kFloat64);
212 1170049 : for (int i = 0; i < samples.size(1); ++i) {
213 3489873 : const torch::Tensor sample = samples.index({ "...", i });
214 1163291 : auto psample = predict_sample(sample);
215 1163269 : auto temp = torch::tensor(psample, torch::kFloat64);
216 : // result.index_put_({ i, "..." }, torch::tensor(predict_sample(sample), torch::kFloat64));
217 3489807 : result.index_put_({ i, "..." }, temp);
218 1163291 : }
219 6758 : if (proba)
220 3540 : return result;
221 6436 : return result.argmax(1);
222 2333340 : }
223 : // Return mxn tensor of probabilities
224 3540 : torch::Tensor Network::predict_proba(const torch::Tensor& samples)
225 : {
226 3540 : return predict_tensor(samples, true);
227 : }
228 :
229 : // Return mxn tensor of probabilities
230 3262 : torch::Tensor Network::predict(const torch::Tensor& samples)
231 : {
232 3262 : return predict_tensor(samples, false);
233 : }
234 :
235 : // Return mx1 std::vector of predictions
236 : // tsamples is nxm std::vector of samples
237 132 : std::vector<int> Network::predict(const std::vector<std::vector<int>>& tsamples)
238 : {
239 132 : if (!fitted) {
240 44 : throw std::logic_error("You must call fit() before calling predict()");
241 : }
242 88 : std::vector<int> predictions;
243 88 : std::vector<int> sample;
244 9801 : for (int row = 0; row < tsamples[0].size(); ++row) {
245 9735 : sample.clear();
246 72193 : for (int col = 0; col < tsamples.size(); ++col) {
247 62458 : sample.push_back(tsamples[col][row]);
248 : }
249 9735 : std::vector<double> classProbabilities = predict_sample(sample);
250 : // Find the class with the maximum posterior probability
251 9713 : auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end());
252 9713 : int predictedClass = distance(classProbabilities.begin(), maxElem);
253 9713 : predictions.push_back(predictedClass);
254 9713 : }
255 132 : return predictions;
256 110 : }
257 : // Return mxn std::vector of probabilities
258 : // tsamples is nxm std::vector of samples
259 777 : std::vector<std::vector<double>> Network::predict_proba(const std::vector<std::vector<int>>& tsamples)
260 : {
261 777 : if (!fitted) {
262 22 : throw std::logic_error("You must call fit() before calling predict_proba()");
263 : }
264 755 : std::vector<std::vector<double>> predictions;
265 755 : std::vector<int> sample;
266 146506 : for (int row = 0; row < tsamples[0].size(); ++row) {
267 145751 : sample.clear();
268 1941951 : for (int col = 0; col < tsamples.size(); ++col) {
269 1796200 : sample.push_back(tsamples[col][row]);
270 : }
271 145751 : predictions.push_back(predict_sample(sample));
272 : }
273 1510 : return predictions;
274 755 : }
275 55 : double Network::score(const std::vector<std::vector<int>>& tsamples, const std::vector<int>& labels)
276 : {
277 55 : std::vector<int> y_pred = predict(tsamples);
278 33 : int correct = 0;
279 6391 : for (int i = 0; i < y_pred.size(); ++i) {
280 6358 : if (y_pred[i] == labels[i]) {
281 5346 : correct++;
282 : }
283 : }
284 66 : return (double)correct / y_pred.size();
285 33 : }
286 : // Return 1xn std::vector of probabilities
287 155486 : std::vector<double> Network::predict_sample(const std::vector<int>& sample)
288 : {
289 : // Ensure the sample size is equal to the number of features
290 155486 : if (sample.size() != features.size() - 1) {
291 44 : throw std::invalid_argument("Sample size (" + std::to_string(sample.size()) +
292 66 : ") does not match the number of features (" + std::to_string(features.size() - 1) + ")");
293 : }
294 155464 : std::map<std::string, int> evidence;
295 2014056 : for (int i = 0; i < sample.size(); ++i) {
296 1858592 : evidence[features[i]] = sample[i];
297 : }
298 310928 : return exactInference(evidence);
299 155464 : }
300 : // Return 1xn std::vector of probabilities
301 1163291 : std::vector<double> Network::predict_sample(const torch::Tensor& sample)
302 : {
303 : // Ensure the sample size is equal to the number of features
304 1163291 : if (sample.size(0) != features.size() - 1) {
305 44 : throw std::invalid_argument("Sample size (" + std::to_string(sample.size(0)) +
306 66 : ") does not match the number of features (" + std::to_string(features.size() - 1) + ")");
307 : }
308 1163269 : std::map<std::string, int> evidence;
309 30202277 : for (int i = 0; i < sample.size(0); ++i) {
310 29039008 : evidence[features[i]] = sample[i].item<int>();
311 : }
312 2326538 : return exactInference(evidence);
313 1163269 : }
314 5150558 : double Network::computeFactor(std::map<std::string, int>& completeEvidence)
315 : {
316 5150558 : double result = 1.0;
317 72453396 : for (auto& node : getNodes()) {
318 67302838 : result *= node.second->getFactorValue(completeEvidence);
319 : }
320 5150558 : return result;
321 : }
322 1318733 : std::vector<double> Network::exactInference(std::map<std::string, int>& evidence)
323 : {
324 1318733 : std::vector<double> result(classNumStates, 0.0);
325 1318733 : std::vector<std::thread> threads;
326 1318733 : std::mutex mtx;
327 6469291 : for (int i = 0; i < classNumStates; ++i) {
328 5150558 : threads.emplace_back([this, &result, &evidence, i, &mtx]() {
329 5150558 : auto completeEvidence = std::map<std::string, int>(evidence);
330 5150558 : completeEvidence[getClassName()] = i;
331 5150558 : double factor = computeFactor(completeEvidence);
332 5150558 : std::lock_guard<std::mutex> lock(mtx);
333 5150558 : result[i] = factor;
334 5150558 : });
335 : }
336 6469291 : for (auto& thread : threads) {
337 5150558 : thread.join();
338 : }
339 : // Normalize result
340 1318733 : double sum = accumulate(result.begin(), result.end(), 0.0);
341 6469291 : transform(result.begin(), result.end(), result.begin(), [sum](const double& value) { return value / sum; });
342 2637466 : return result;
343 1318733 : }
344 77 : std::vector<std::string> Network::show() const
345 : {
346 77 : std::vector<std::string> result;
347 : // Draw the network
348 440 : for (auto& node : nodes) {
349 363 : std::string line = node.first + " -> ";
350 847 : for (auto child : node.second->getChildren()) {
351 484 : line += child->getName() + ", ";
352 : }
353 363 : result.push_back(line);
354 363 : }
355 77 : return result;
356 0 : }
357 242 : std::vector<std::string> Network::graph(const std::string& title) const
358 : {
359 242 : auto output = std::vector<std::string>();
360 242 : auto prefix = "digraph BayesNet {\nlabel=<BayesNet ";
361 242 : auto suffix = ">\nfontsize=30\nfontcolor=blue\nlabelloc=t\nlayout=circo\n";
362 242 : std::string header = prefix + title + suffix;
363 242 : output.push_back(header);
364 1925 : for (auto& node : nodes) {
365 1683 : auto result = node.second->graph(className);
366 1683 : output.insert(output.end(), result.begin(), result.end());
367 1683 : }
368 242 : output.push_back("}\n");
369 484 : return output;
370 242 : }
371 684 : std::vector<std::pair<std::string, std::string>> Network::getEdges() const
372 : {
373 684 : auto edges = std::vector<std::pair<std::string, std::string>>();
374 10684 : for (const auto& node : nodes) {
375 10000 : auto head = node.first;
376 27937 : for (const auto& child : node.second->getChildren()) {
377 17937 : auto tail = child->getName();
378 17937 : edges.push_back({ head, tail });
379 17937 : }
380 10000 : }
381 684 : return edges;
382 0 : }
383 563 : int Network::getNumEdges() const
384 : {
385 563 : return getEdges().size();
386 : }
387 605 : std::vector<std::string> Network::topological_sort()
388 : {
389 : /* Check if al the fathers of every node are before the node */
390 605 : auto result = features;
391 605 : result.erase(remove(result.begin(), result.end(), className), result.end());
392 605 : bool ending{ false };
393 1727 : while (!ending) {
394 1122 : ending = true;
395 10461 : for (auto feature : features) {
396 9339 : auto fathers = nodes[feature]->getParents();
397 24750 : for (const auto& father : fathers) {
398 15411 : auto fatherName = father->getName();
399 15411 : if (fatherName == className) {
400 8195 : continue;
401 : }
402 : // Check if father is placed before the actual feature
403 7216 : auto it = find(result.begin(), result.end(), fatherName);
404 7216 : if (it != result.end()) {
405 7216 : auto it2 = find(result.begin(), result.end(), feature);
406 7216 : if (it2 != result.end()) {
407 7216 : if (distance(it, it2) < 0) {
408 : // if it is not, insert it before the feature
409 671 : result.erase(remove(result.begin(), result.end(), fatherName), result.end());
410 671 : result.insert(it2, fatherName);
411 671 : ending = false;
412 : }
413 : } else {
414 0 : throw std::logic_error("Error in topological sort because of node " + feature + " is not in result");
415 : }
416 : } else {
417 0 : throw std::logic_error("Error in topological sort because of node father " + fatherName + " is not in result");
418 : }
419 15411 : }
420 9339 : }
421 : }
422 605 : return result;
423 0 : }
424 22 : std::string Network::dump_cpt() const
425 : {
426 22 : std::stringstream oss;
427 132 : for (auto& node : nodes) {
428 110 : oss << "* " << node.first << ": (" << node.second->getNumStates() << ") : " << node.second->getCPT().sizes() << std::endl;
429 110 : oss << node.second->getCPT() << std::endl;
430 : }
431 44 : return oss.str();
432 22 : }
433 : }
|