63 KiB
63 KiB
<html lang="en">
<head>
</head>
</html>
LCOV - code coverage report | |||||||||||||||||||||||||
![]() | |||||||||||||||||||||||||
|
|||||||||||||||||||||||||
![]() |
Line data Source code 1 : // *************************************************************** 2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez 3 : // SPDX-FileType: SOURCE 4 : // SPDX-License-Identifier: MIT 5 : // *************************************************************** 6 : 7 : #include <thread> 8 : #include <mutex> 9 : #include <sstream> 10 : #include "Network.h" 11 : #include "bayesnet/utils/bayesnetUtils.h" 12 : namespace bayesnet { 13 2332 : Network::Network() : fitted{ false }, maxThreads{ 0.95 }, classNumStates{ 0 }, laplaceSmoothing{ 0 } 14 : { 15 2332 : } 16 8 : Network::Network(float maxT) : fitted{ false }, maxThreads{ maxT }, classNumStates{ 0 }, laplaceSmoothing{ 0 } 17 : { 18 : 19 8 : } 20 2244 : Network::Network(const Network& other) : laplaceSmoothing(other.laplaceSmoothing), features(other.features), className(other.className), classNumStates(other.getClassNumStates()), 21 4488 : maxThreads(other.getMaxThreads()), fitted(other.fitted), samples(other.samples) 22 : { 23 2244 : if (samples.defined()) 24 4 : samples = samples.clone(); 25 2264 : for (const auto& node : other.nodes) { 26 20 : nodes[node.first] = std::make_unique<Node>(*node.second); 27 : } 28 2244 : } 29 1740 : void Network::initialize() 30 : { 31 1740 : features.clear(); 32 1740 : className = ""; 33 1740 : classNumStates = 0; 34 1740 : fitted = false; 35 1740 : nodes.clear(); 36 1740 : samples = torch::Tensor(); 37 1740 : } 38 2256 : float Network::getMaxThreads() const 39 : { 40 2256 : return maxThreads; 41 : } 42 48 : torch::Tensor& Network::getSamples() 43 : { 44 48 : return samples; 45 : } 46 31216 : void Network::addNode(const std::string& name) 47 : { 48 31216 : if (name == "") { 49 8 : throw std::invalid_argument("Node name cannot be empty"); 50 : } 51 31208 : if (nodes.find(name) != nodes.end()) { 52 4 : return; 53 : } 54 31204 : if (find(features.begin(), features.end(), name) == features.end()) { 55 31204 : features.push_back(name); 56 : } 57 31204 : nodes[name] = std::make_unique<Node>(name); 58 : } 59 380 : std::vector<std::string> Network::getFeatures() const 60 : { 61 380 : return features; 62 : } 63 2616 : int Network::getClassNumStates() const 64 : { 65 2616 : return classNumStates; 66 : } 67 48 : int Network::getStates() const 68 : { 69 48 : int result = 0; 70 288 : for (auto& node : nodes) { 71 240 : result += node.second->getNumStates(); 72 : } 73 48 : return result; 74 : } 75 3735008 : std::string Network::getClassName() const 76 : { 77 3735008 : return className; 78 : } 79 70324 : bool Network::isCyclic(const std::string& nodeId, std::unordered_set<std::string>& visited, std::unordered_set<std::string>& recStack) 80 : { 81 70324 : if (visited.find(nodeId) == visited.end()) // if node hasn't been visited yet 82 : { 83 70324 : visited.insert(nodeId); 84 70324 : recStack.insert(nodeId); 85 81496 : for (Node* child : nodes[nodeId]->getChildren()) { 86 11196 : if (visited.find(child->getName()) == visited.end() && isCyclic(child->getName(), visited, recStack)) 87 24 : return true; 88 11180 : if (recStack.find(child->getName()) != recStack.end()) 89 8 : return true; 90 : } 91 : } 92 70300 : recStack.erase(nodeId); // remove node from recursion stack before function ends 93 70300 : return false; 94 : } 95 59152 : void Network::addEdge(const std::string& parent, const std::string& child) 96 : { 97 59152 : if (nodes.find(parent) == nodes.end()) { 98 8 : throw std::invalid_argument("Parent node " + parent + " does not exist"); 99 : } 100 59144 : if (nodes.find(child) == nodes.end()) { 101 8 : throw std::invalid_argument("Child node " + child + " does not exist"); 102 : } 103 : // Temporarily add edge to check for cycles 104 59136 : nodes[parent]->addChild(nodes[child].get()); 105 59136 : nodes[child]->addParent(nodes[parent].get()); 106 59136 : std::unordered_set<std::string> visited; 107 59136 : std::unordered_set<std::string> recStack; 108 59136 : if (isCyclic(nodes[child]->getName(), visited, recStack)) // if adding this edge forms a cycle 109 : { 110 : // remove problematic edge 111 8 : nodes[parent]->removeChild(nodes[child].get()); 112 8 : nodes[child]->removeParent(nodes[parent].get()); 113 8 : throw std::invalid_argument("Adding this edge forms a cycle in the graph."); 114 : } 115 59144 : } 116 3735276 : std::map<std::string, std::unique_ptr<Node>>& Network::getNodes() 117 : { 118 3735276 : return nodes; 119 : } 120 1888 : void Network::checkFitData(int n_samples, int n_features, int n_samples_y, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights) 121 : { 122 1888 : if (weights.size(0) != n_samples) { 123 8 : throw std::invalid_argument("Weights (" + std::to_string(weights.size(0)) + ") must have the same number of elements as samples (" + std::to_string(n_samples) + ") in Network::fit"); 124 : } 125 1880 : if (n_samples != n_samples_y) { 126 8 : throw std::invalid_argument("X and y must have the same number of samples in Network::fit (" + std::to_string(n_samples) + " != " + std::to_string(n_samples_y) + ")"); 127 : } 128 1872 : if (n_features != featureNames.size()) { 129 8 : throw std::invalid_argument("X and features must have the same number of features in Network::fit (" + std::to_string(n_features) + " != " + std::to_string(featureNames.size()) + ")"); 130 : } 131 1864 : if (features.size() == 0) { 132 8 : throw std::invalid_argument("The network has not been initialized. You must call addNode() before calling fit()"); 133 : } 134 1856 : if (n_features != features.size() - 1) { 135 8 : throw std::invalid_argument("X and local features must have the same number of features in Network::fit (" + std::to_string(n_features) + " != " + std::to_string(features.size() - 1) + ")"); 136 : } 137 1848 : if (find(features.begin(), features.end(), className) == features.end()) { 138 8 : throw std::invalid_argument("Class Name not found in Network::features"); 139 : } 140 32868 : for (auto& feature : featureNames) { 141 31044 : if (find(features.begin(), features.end(), feature) == features.end()) { 142 8 : throw std::invalid_argument("Feature " + feature + " not found in Network::features"); 143 : } 144 31036 : if (states.find(feature) == states.end()) { 145 8 : throw std::invalid_argument("Feature " + feature + " not found in states"); 146 : } 147 : } 148 1824 : } 149 1824 : void Network::setStates(const std::map<std::string, std::vector<int>>& states) 150 : { 151 : // Set states to every Node in the network 152 1824 : for_each(features.begin(), features.end(), [this, &states](const std::string& feature) { 153 32828 : nodes.at(feature)->setNumStates(states.at(feature).size()); 154 32828 : }); 155 1824 : classNumStates = nodes.at(className)->getNumStates(); 156 1824 : } 157 : // X comes in nxm, where n is the number of features and m the number of samples 158 4 : void Network::fit(const torch::Tensor& X, const torch::Tensor& y, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states) 159 : { 160 4 : checkFitData(X.size(1), X.size(0), y.size(0), featureNames, className, states, weights); 161 4 : this->className = className; 162 4 : torch::Tensor ytmp = torch::transpose(y.view({ y.size(0), 1 }), 0, 1); 163 12 : samples = torch::cat({ X , ytmp }, 0); 164 20 : for (int i = 0; i < featureNames.size(); ++i) { 165 48 : auto row_feature = X.index({ i, "..." }); 166 16 : } 167 4 : completeFit(states, weights); 168 24 : } 169 1792 : void Network::fit(const torch::Tensor& samples, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states) 170 : { 171 1792 : checkFitData(samples.size(1), samples.size(0) - 1, samples.size(1), featureNames, className, states, weights); 172 1792 : this->className = className; 173 1792 : this->samples = samples; 174 1792 : completeFit(states, weights); 175 1792 : } 176 : // input_data comes in nxm, where n is the number of features and m the number of samples 177 92 : void Network::fit(const std::vector<std::vector<int>>& input_data, const std::vector<int>& labels, const std::vector<double>& weights_, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states) 178 : { 179 92 : const torch::Tensor weights = torch::tensor(weights_, torch::kFloat64); 180 92 : checkFitData(input_data[0].size(), input_data.size(), labels.size(), featureNames, className, states, weights); 181 28 : this->className = className; 182 : // Build tensor of samples (nxm) (n+1 because of the class) 183 28 : samples = torch::zeros({ static_cast<int>(input_data.size() + 1), static_cast<int>(input_data[0].size()) }, torch::kInt32); 184 140 : for (int i = 0; i < featureNames.size(); ++i) { 185 448 : samples.index_put_({ i, "..." }, torch::tensor(input_data[i], torch::kInt32)); 186 : } 187 112 : samples.index_put_({ -1, "..." }, torch::tensor(labels, torch::kInt32)); 188 28 : completeFit(states, weights); 189 232 : } 190 1824 : void Network::completeFit(const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights) 191 : { 192 1824 : setStates(states); 193 1824 : laplaceSmoothing = 1.0 / samples.size(1); // To use in CPT computation 194 1824 : std::vector<std::thread> threads; 195 34652 : for (auto& node : nodes) { 196 32828 : threads.emplace_back([this, &node, &weights]() { 197 32828 : node.second->computeCPT(samples, features, laplaceSmoothing, weights); 198 32828 : }); 199 : } 200 34652 : for (auto& thread : threads) { 201 32828 : thread.join(); 202 : } 203 1824 : fitted = true; 204 1824 : } 205 3320 : torch::Tensor Network::predict_tensor(const torch::Tensor& samples, const bool proba) 206 : { 207 3320 : if (!fitted) { 208 8 : throw std::logic_error("You must call fit() before calling predict()"); 209 : } 210 3312 : torch::Tensor result; 211 3312 : result = torch::zeros({ samples.size(1), classNumStates }, torch::kFloat64); 212 785016 : for (int i = 0; i < samples.size(1); ++i) { 213 2345136 : const torch::Tensor sample = samples.index({ "...", i }); 214 781712 : auto psample = predict_sample(sample); 215 781704 : auto temp = torch::tensor(psample, torch::kFloat64); 216 : // result.index_put_({ i, "..." }, torch::tensor(predict_sample(sample), torch::kFloat64)); 217 2345112 : result.index_put_({ i, "..." }, temp); 218 781712 : } 219 3304 : if (proba) 220 1476 : return result; 221 3656 : return result.argmax(1); 222 1566728 : } 223 : // Return mxn tensor of probabilities 224 1476 : torch::Tensor Network::predict_proba(const torch::Tensor& samples) 225 : { 226 1476 : return predict_tensor(samples, true); 227 : } 228 : 229 : // Return mxn tensor of probabilities 230 1844 : torch::Tensor Network::predict(const torch::Tensor& samples) 231 : { 232 1844 : return predict_tensor(samples, false); 233 : } 234 : 235 : // Return mx1 std::vector of predictions 236 : // tsamples is nxm std::vector of samples 237 48 : std::vector<int> Network::predict(const std::vector<std::vector<int>>& tsamples) 238 : { 239 48 : if (!fitted) { 240 16 : throw std::logic_error("You must call fit() before calling predict()"); 241 : } 242 32 : std::vector<int> predictions; 243 32 : std::vector<int> sample; 244 3564 : for (int row = 0; row < tsamples[0].size(); ++row) { 245 3540 : sample.clear(); 246 26252 : for (int col = 0; col < tsamples.size(); ++col) { 247 22712 : sample.push_back(tsamples[col][row]); 248 : } 249 3540 : std::vector<double> classProbabilities = predict_sample(sample); 250 : // Find the class with the maximum posterior probability 251 3532 : auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end()); 252 3532 : int predictedClass = distance(classProbabilities.begin(), maxElem); 253 3532 : predictions.push_back(predictedClass); 254 3532 : } 255 48 : return predictions; 256 40 : } 257 : // Return mxn std::vector of probabilities 258 : // tsamples is nxm std::vector of samples 259 552 : std::vector<std::vector<double>> Network::predict_proba(const std::vector<std::vector<int>>& tsamples) 260 : { 261 552 : if (!fitted) { 262 8 : throw std::logic_error("You must call fit() before calling predict_proba()"); 263 : } 264 544 : std::vector<std::vector<double>> predictions; 265 544 : std::vector<int> sample; 266 111516 : for (int row = 0; row < tsamples[0].size(); ++row) { 267 110972 : sample.clear(); 268 1055620 : for (int col = 0; col < tsamples.size(); ++col) { 269 944648 : sample.push_back(tsamples[col][row]); 270 : } 271 110972 : predictions.push_back(predict_sample(sample)); 272 : } 273 1088 : return predictions; 274 544 : } 275 20 : double Network::score(const std::vector<std::vector<int>>& tsamples, const std::vector<int>& labels) 276 : { 277 20 : std::vector<int> y_pred = predict(tsamples); 278 12 : int correct = 0; 279 2324 : for (int i = 0; i < y_pred.size(); ++i) { 280 2312 : if (y_pred[i] == labels[i]) { 281 1944 : correct++; 282 : } 283 : } 284 24 : return (double)correct / y_pred.size(); 285 12 : } 286 : // Return 1xn std::vector of probabilities 287 114512 : std::vector<double> Network::predict_sample(const std::vector<int>& sample) 288 : { 289 : // Ensure the sample size is equal to the number of features 290 114512 : if (sample.size() != features.size() - 1) { 291 16 : throw std::invalid_argument("Sample size (" + std::to_string(sample.size()) + 292 24 : ") does not match the number of features (" + std::to_string(features.size() - 1) + ")"); 293 : } 294 114504 : std::map<std::string, int> evidence; 295 1081840 : for (int i = 0; i < sample.size(); ++i) { 296 967336 : evidence[features[i]] = sample[i]; 297 : } 298 229008 : return exactInference(evidence); 299 114504 : } 300 : // Return 1xn std::vector of probabilities 301 781712 : std::vector<double> Network::predict_sample(const torch::Tensor& sample) 302 : { 303 : // Ensure the sample size is equal to the number of features 304 781712 : if (sample.size(0) != features.size() - 1) { 305 16 : throw std::invalid_argument("Sample size (" + std::to_string(sample.size(0)) + 306 24 : ") does not match the number of features (" + std::to_string(features.size() - 1) + ")"); 307 : } 308 781704 : std::map<std::string, int> evidence; 309 18085136 : for (int i = 0; i < sample.size(0); ++i) { 310 17303432 : evidence[features[i]] = sample[i].item<int>(); 311 : } 312 1563408 : return exactInference(evidence); 313 781704 : } 314 3734984 : double Network::computeFactor(std::map<std::string, int>& completeEvidence) 315 : { 316 3734984 : double result = 1.0; 317 72886736 : for (auto& node : getNodes()) { 318 69151752 : result *= node.second->getFactorValue(completeEvidence); 319 : } 320 3734984 : return result; 321 : } 322 896208 : std::vector<double> Network::exactInference(std::map<std::string, int>& evidence) 323 : { 324 896208 : std::vector<double> result(classNumStates, 0.0); 325 896208 : std::vector<std::thread> threads; 326 896208 : std::mutex mtx; 327 4631192 : for (int i = 0; i < classNumStates; ++i) { 328 3734984 : threads.emplace_back([this, &result, &evidence, i, &mtx]() { 329 3734984 : auto completeEvidence = std::map<std::string, int>(evidence); 330 3734984 : completeEvidence[getClassName()] = i; 331 3734984 : double factor = computeFactor(completeEvidence); 332 3734984 : std::lock_guard<std::mutex> lock(mtx); 333 3734984 : result[i] = factor; 334 3734984 : }); 335 : } 336 4631192 : for (auto& thread : threads) { 337 3734984 : thread.join(); 338 : } 339 : // Normalize result 340 896208 : double sum = accumulate(result.begin(), result.end(), 0.0); 341 4631192 : transform(result.begin(), result.end(), result.begin(), [sum](const double& value) { return value / sum; }); 342 1792416 : return result; 343 896208 : } 344 28 : std::vector<std::string> Network::show() const 345 : { 346 28 : std::vector<std::string> result; 347 : // Draw the network 348 160 : for (auto& node : nodes) { 349 132 : std::string line = node.first + " -> "; 350 308 : for (auto child : node.second->getChildren()) { 351 176 : line += child->getName() + ", "; 352 : } 353 132 : result.push_back(line); 354 132 : } 355 56 : return result; 356 28 : } 357 112 : std::vector<std::string> Network::graph(const std::string& title) const 358 : { 359 112 : auto output = std::vector<std::string>(); 360 112 : auto prefix = "digraph BayesNet {\nlabel=<BayesNet "; 361 112 : auto suffix = ">\nfontsize=30\nfontcolor=blue\nlabelloc=t\nlayout=circo\n"; 362 112 : std::string header = prefix + title + suffix; 363 112 : output.push_back(header); 364 844 : for (auto& node : nodes) { 365 732 : auto result = node.second->graph(className); 366 732 : output.insert(output.end(), result.begin(), result.end()); 367 732 : } 368 112 : output.push_back("}\n"); 369 224 : return output; 370 112 : } 371 408 : std::vector<std::pair<std::string, std::string>> Network::getEdges() const 372 : { 373 408 : auto edges = std::vector<std::pair<std::string, std::string>>(); 374 7396 : for (const auto& node : nodes) { 375 6988 : auto head = node.first; 376 20312 : for (const auto& child : node.second->getChildren()) { 377 13324 : auto tail = child->getName(); 378 13324 : edges.push_back({ head, tail }); 379 13324 : } 380 6988 : } 381 816 : return edges; 382 408 : } 383 364 : int Network::getNumEdges() const 384 : { 385 364 : return getEdges().size(); 386 : } 387 220 : std::vector<std::string> Network::topological_sort() 388 : { 389 : /* Check if al the fathers of every node are before the node */ 390 220 : auto result = features; 391 220 : result.erase(remove(result.begin(), result.end(), className), result.end()); 392 220 : bool ending{ false }; 393 628 : while (!ending) { 394 408 : ending = true; 395 3804 : for (auto feature : features) { 396 3396 : auto fathers = nodes[feature]->getParents(); 397 9000 : for (const auto& father : fathers) { 398 5604 : auto fatherName = father->getName(); 399 5604 : if (fatherName == className) { 400 2980 : continue; 401 : } 402 : // Check if father is placed before the actual feature 403 2624 : auto it = find(result.begin(), result.end(), fatherName); 404 2624 : if (it != result.end()) { 405 2624 : auto it2 = find(result.begin(), result.end(), feature); 406 2624 : if (it2 != result.end()) { 407 5248 : if (distance(it, it2) < 0) { 408 : // if it is not, insert it before the feature 409 244 : result.erase(remove(result.begin(), result.end(), fatherName), result.end()); 410 244 : result.insert(it2, fatherName); 411 244 : ending = false; 412 : } 413 : } 414 : } 415 5604 : } 416 3396 : } 417 : } 418 440 : return result; 419 220 : } 420 8 : std::string Network::dump_cpt() const 421 : { 422 8 : std::stringstream oss; 423 48 : for (auto& node : nodes) { 424 40 : oss << "* " << node.first << ": (" << node.second->getNumStates() << ") : " << node.second->getCPT().sizes() << std::endl; 425 40 : oss << node.second->getCPT() << std::endl; 426 : } 427 16 : return oss.str(); 428 8 : } 429 : } |
![]() |
Generated by: LCOV version 2.0-1 |
</html>