Files
BayesNet/html/bayesnet/network/Network.cc.gcov.html
2024-04-30 00:52:09 +02:00

63 KiB

<html lang="en"> <head> </head>
LCOV - code coverage report
Current view: top level - bayesnet/network - Network.cc (source / functions) Coverage Total Hit
Test: coverage.info Lines: 97.6 % 297 290
Test Date: 2024-04-29 20:48:03 Functions: 100.0 % 40 40

            Line data    Source code
       1              : // ***************************************************************
       2              : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
       3              : // SPDX-FileType: SOURCE
       4              : // SPDX-License-Identifier: MIT
       5              : // ***************************************************************
       6              : 
       7              : #include <thread>
       8              : #include <mutex>
       9              : #include <sstream>
      10              : #include "Network.h"
      11              : #include "bayesnet/utils/bayesnetUtils.h"
      12              : namespace bayesnet {
      13         4992 :     Network::Network() : fitted{ false }, maxThreads{ 0.95 }, classNumStates{ 0 }, laplaceSmoothing{ 0 }
      14              :     {
      15         4992 :     }
      16           22 :     Network::Network(float maxT) : fitted{ false }, maxThreads{ maxT }, classNumStates{ 0 }, laplaceSmoothing{ 0 }
      17              :     {
      18              : 
      19           22 :     }
      20         4761 :     Network::Network(const Network& other) : laplaceSmoothing(other.laplaceSmoothing), features(other.features), className(other.className), classNumStates(other.getClassNumStates()),
      21         9522 :         maxThreads(other.getMaxThreads()), fitted(other.fitted), samples(other.samples)
      22              :     {
      23         4761 :         if (samples.defined())
      24           11 :             samples = samples.clone();
      25         4816 :         for (const auto& node : other.nodes) {
      26           55 :             nodes[node.first] = std::make_unique<Node>(*node.second);
      27              :         }
      28         4761 :     }
      29         3358 :     void Network::initialize()
      30              :     {
      31         3358 :         features.clear();
      32         3358 :         className = "";
      33         3358 :         classNumStates = 0;
      34         3358 :         fitted = false;
      35         3358 :         nodes.clear();
      36         3358 :         samples = torch::Tensor();
      37         3358 :     }
      38         4794 :     float Network::getMaxThreads() const
      39              :     {
      40         4794 :         return maxThreads;
      41              :     }
      42          132 :     torch::Tensor& Network::getSamples()
      43              :     {
      44          132 :         return samples;
      45              :     }
      46       116878 :     void Network::addNode(const std::string& name)
      47              :     {
      48       116878 :         if (name == "") {
      49           22 :             throw std::invalid_argument("Node name cannot be empty");
      50              :         }
      51       116856 :         if (nodes.find(name) != nodes.end()) {
      52            0 :             return;
      53              :         }
      54       116856 :         if (find(features.begin(), features.end(), name) == features.end()) {
      55       116856 :             features.push_back(name);
      56              :         }
      57       116856 :         nodes[name] = std::make_unique<Node>(name);
      58              :     }
      59          607 :     std::vector<std::string> Network::getFeatures() const
      60              :     {
      61          607 :         return features;
      62              :     }
      63         5704 :     int Network::getClassNumStates() const
      64              :     {
      65         5704 :         return classNumStates;
      66              :     }
      67          132 :     int Network::getStates() const
      68              :     {
      69          132 :         int result = 0;
      70          792 :         for (auto& node : nodes) {
      71          660 :             result += node.second->getNumStates();
      72              :         }
      73          132 :         return result;
      74              :     }
      75      5150624 :     std::string Network::getClassName() const
      76              :     {
      77      5150624 :         return className;
      78              :     }
      79       295830 :     bool Network::isCyclic(const std::string& nodeId, std::unordered_set<std::string>& visited, std::unordered_set<std::string>& recStack)
      80              :     {
      81       295830 :         if (visited.find(nodeId) == visited.end()) // if node hasn't been visited yet
      82              :         {
      83       295830 :             visited.insert(nodeId);
      84       295830 :             recStack.insert(nodeId);
      85       367384 :             for (Node* child : nodes[nodeId]->getChildren()) {
      86        71620 :                 if (visited.find(child->getName()) == visited.end() && isCyclic(child->getName(), visited, recStack))
      87           66 :                     return true;
      88        71576 :                 if (recStack.find(child->getName()) != recStack.end())
      89           22 :                     return true;
      90              :             }
      91              :         }
      92       295764 :         recStack.erase(nodeId); // remove node from recursion stack before function ends
      93       295764 :         return false;
      94              :     }
      95       224276 :     void Network::addEdge(const std::string& parent, const std::string& child)
      96              :     {
      97       224276 :         if (nodes.find(parent) == nodes.end()) {
      98           22 :             throw std::invalid_argument("Parent node " + parent + " does not exist");
      99              :         }
     100       224254 :         if (nodes.find(child) == nodes.end()) {
     101           22 :             throw std::invalid_argument("Child node " + child + " does not exist");
     102              :         }
     103              :         // Temporarily add edge to check for cycles
     104       224232 :         nodes[parent]->addChild(nodes[child].get());
     105       224232 :         nodes[child]->addParent(nodes[parent].get());
     106       224232 :         std::unordered_set<std::string> visited;
     107       224232 :         std::unordered_set<std::string> recStack;
     108       224232 :         if (isCyclic(nodes[child]->getName(), visited, recStack)) // if adding this edge forms a cycle
     109              :         {
     110              :             // remove problematic edge
     111           22 :             nodes[parent]->removeChild(nodes[child].get());
     112           22 :             nodes[child]->removeParent(nodes[parent].get());
     113           22 :             throw std::invalid_argument("Adding this edge forms a cycle in the graph.");
     114              :         }
     115       224254 :     }
     116      5151361 :     std::map<std::string, std::unique_ptr<Node>>& Network::getNodes()
     117              :     {
     118      5151361 :         return nodes;
     119              :     }
     120         3787 :     void Network::checkFitData(int n_samples, int n_features, int n_samples_y, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights)
     121              :     {
     122         3787 :         if (weights.size(0) != n_samples) {
     123           22 :             throw std::invalid_argument("Weights (" + std::to_string(weights.size(0)) + ") must have the same number of elements as samples (" + std::to_string(n_samples) + ") in Network::fit");
     124              :         }
     125         3765 :         if (n_samples != n_samples_y) {
     126           22 :             throw std::invalid_argument("X and y must have the same number of samples in Network::fit (" + std::to_string(n_samples) + " != " + std::to_string(n_samples_y) + ")");
     127              :         }
     128         3743 :         if (n_features != featureNames.size()) {
     129           22 :             throw std::invalid_argument("X and features must have the same number of features in Network::fit (" + std::to_string(n_features) + " != " + std::to_string(featureNames.size()) + ")");
     130              :         }
     131         3721 :         if (features.size() == 0) {
     132           22 :             throw std::invalid_argument("The network has not been initialized. You must call addNode() before calling fit()");
     133              :         }
     134         3699 :         if (n_features != features.size() - 1) {
     135           22 :             throw std::invalid_argument("X and local features must have the same number of features in Network::fit (" + std::to_string(n_features) + " != " + std::to_string(features.size() - 1) + ")");
     136              :         }
     137         3677 :         if (find(features.begin(), features.end(), className) == features.end()) {
     138           22 :             throw std::invalid_argument("Class Name not found in Network::features");
     139              :         }
     140       121476 :         for (auto& feature : featureNames) {
     141       117843 :             if (find(features.begin(), features.end(), feature) == features.end()) {
     142           22 :                 throw std::invalid_argument("Feature " + feature + " not found in Network::features");
     143              :             }
     144       117821 :             if (states.find(feature) == states.end()) {
     145            0 :                 throw std::invalid_argument("Feature " + feature + " not found in states");
     146              :             }
     147              :         }
     148         3633 :     }
     149         3633 :     void Network::setStates(const std::map<std::string, std::vector<int>>& states)
     150              :     {
     151              :         // Set states to every Node in the network
     152         3633 :         for_each(features.begin(), features.end(), [this, &states](const std::string& feature) {
     153       121388 :             nodes.at(feature)->setNumStates(states.at(feature).size());
     154       121388 :             });
     155         3633 :         classNumStates = nodes.at(className)->getNumStates();
     156         3633 :     }
     157              :     // X comes in nxm, where n is the number of features and m the number of samples
     158           11 :     void Network::fit(const torch::Tensor& X, const torch::Tensor& y, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states)
     159              :     {
     160           11 :         checkFitData(X.size(1), X.size(0), y.size(0), featureNames, className, states, weights);
     161           11 :         this->className = className;
     162           11 :         torch::Tensor ytmp = torch::transpose(y.view({ y.size(0), 1 }), 0, 1);
     163           33 :         samples = torch::cat({ X , ytmp }, 0);
     164           55 :         for (int i = 0; i < featureNames.size(); ++i) {
     165          132 :             auto row_feature = X.index({ i, "..." });
     166           44 :         }
     167           11 :         completeFit(states, weights);
     168           66 :     }
     169         3545 :     void Network::fit(const torch::Tensor& samples, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states)
     170              :     {
     171         3545 :         checkFitData(samples.size(1), samples.size(0) - 1, samples.size(1), featureNames, className, states, weights);
     172         3545 :         this->className = className;
     173         3545 :         this->samples = samples;
     174         3545 :         completeFit(states, weights);
     175         3545 :     }
     176              :     // input_data comes in nxm, where n is the number of features and m the number of samples
     177          231 :     void Network::fit(const std::vector<std::vector<int>>& input_data, const std::vector<int>& labels, const std::vector<double>& weights_, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states)
     178              :     {
     179          231 :         const torch::Tensor weights = torch::tensor(weights_, torch::kFloat64);
     180          231 :         checkFitData(input_data[0].size(), input_data.size(), labels.size(), featureNames, className, states, weights);
     181           77 :         this->className = className;
     182              :         // Build tensor of samples (nxm) (n+1 because of the class)
     183           77 :         samples = torch::zeros({ static_cast<int>(input_data.size() + 1), static_cast<int>(input_data[0].size()) }, torch::kInt32);
     184          385 :         for (int i = 0; i < featureNames.size(); ++i) {
     185         1232 :             samples.index_put_({ i, "..." }, torch::tensor(input_data[i], torch::kInt32));
     186              :         }
     187          308 :         samples.index_put_({ -1, "..." }, torch::tensor(labels, torch::kInt32));
     188           77 :         completeFit(states, weights);
     189          616 :     }
     190         3633 :     void Network::completeFit(const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights)
     191              :     {
     192         3633 :         setStates(states);
     193         3633 :         laplaceSmoothing = 1.0 / samples.size(1); // To use in CPT computation
     194         3633 :         std::vector<std::thread> threads;
     195       125021 :         for (auto& node : nodes) {
     196       121388 :             threads.emplace_back([this, &node, &weights]() {
     197       121388 :                 node.second->computeCPT(samples, features, laplaceSmoothing, weights);
     198       121388 :                 });
     199              :         }
     200       125021 :         for (auto& thread : threads) {
     201       121388 :             thread.join();
     202              :         }
     203         3633 :         fitted = true;
     204         3633 :     }
     205         6802 :     torch::Tensor Network::predict_tensor(const torch::Tensor& samples, const bool proba)
     206              :     {
     207         6802 :         if (!fitted) {
     208           22 :             throw std::logic_error("You must call fit() before calling predict()");
     209              :         }
     210         6780 :         torch::Tensor result;
     211         6780 :         result = torch::zeros({ samples.size(1), classNumStates }, torch::kFloat64);
     212      1170049 :         for (int i = 0; i < samples.size(1); ++i) {
     213      3489873 :             const torch::Tensor sample = samples.index({ "...", i });
     214      1163291 :             auto psample = predict_sample(sample);
     215      1163269 :             auto temp = torch::tensor(psample, torch::kFloat64);
     216              :             //            result.index_put_({ i, "..." }, torch::tensor(predict_sample(sample), torch::kFloat64));
     217      3489807 :             result.index_put_({ i, "..." }, temp);
     218      1163291 :         }
     219         6758 :         if (proba)
     220         3540 :             return result;
     221         6436 :         return result.argmax(1);
     222      2333340 :     }
     223              :     // Return mxn tensor of probabilities
     224         3540 :     torch::Tensor Network::predict_proba(const torch::Tensor& samples)
     225              :     {
     226         3540 :         return predict_tensor(samples, true);
     227              :     }
     228              : 
     229              :     // Return mxn tensor of probabilities
     230         3262 :     torch::Tensor Network::predict(const torch::Tensor& samples)
     231              :     {
     232         3262 :         return predict_tensor(samples, false);
     233              :     }
     234              : 
     235              :     // Return mx1 std::vector of predictions
     236              :     // tsamples is nxm std::vector of samples
     237          132 :     std::vector<int> Network::predict(const std::vector<std::vector<int>>& tsamples)
     238              :     {
     239          132 :         if (!fitted) {
     240           44 :             throw std::logic_error("You must call fit() before calling predict()");
     241              :         }
     242           88 :         std::vector<int> predictions;
     243           88 :         std::vector<int> sample;
     244         9801 :         for (int row = 0; row < tsamples[0].size(); ++row) {
     245         9735 :             sample.clear();
     246        72193 :             for (int col = 0; col < tsamples.size(); ++col) {
     247        62458 :                 sample.push_back(tsamples[col][row]);
     248              :             }
     249         9735 :             std::vector<double> classProbabilities = predict_sample(sample);
     250              :             // Find the class with the maximum posterior probability
     251         9713 :             auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end());
     252         9713 :             int predictedClass = distance(classProbabilities.begin(), maxElem);
     253         9713 :             predictions.push_back(predictedClass);
     254         9713 :         }
     255          132 :         return predictions;
     256          110 :     }
     257              :     // Return mxn std::vector of probabilities
     258              :     // tsamples is nxm std::vector of samples
     259          777 :     std::vector<std::vector<double>> Network::predict_proba(const std::vector<std::vector<int>>& tsamples)
     260              :     {
     261          777 :         if (!fitted) {
     262           22 :             throw std::logic_error("You must call fit() before calling predict_proba()");
     263              :         }
     264          755 :         std::vector<std::vector<double>> predictions;
     265          755 :         std::vector<int> sample;
     266       146506 :         for (int row = 0; row < tsamples[0].size(); ++row) {
     267       145751 :             sample.clear();
     268      1941951 :             for (int col = 0; col < tsamples.size(); ++col) {
     269      1796200 :                 sample.push_back(tsamples[col][row]);
     270              :             }
     271       145751 :             predictions.push_back(predict_sample(sample));
     272              :         }
     273         1510 :         return predictions;
     274          755 :     }
     275           55 :     double Network::score(const std::vector<std::vector<int>>& tsamples, const std::vector<int>& labels)
     276              :     {
     277           55 :         std::vector<int> y_pred = predict(tsamples);
     278           33 :         int correct = 0;
     279         6391 :         for (int i = 0; i < y_pred.size(); ++i) {
     280         6358 :             if (y_pred[i] == labels[i]) {
     281         5346 :                 correct++;
     282              :             }
     283              :         }
     284           66 :         return (double)correct / y_pred.size();
     285           33 :     }
     286              :     // Return 1xn std::vector of probabilities
     287       155486 :     std::vector<double> Network::predict_sample(const std::vector<int>& sample)
     288              :     {
     289              :         // Ensure the sample size is equal to the number of features
     290       155486 :         if (sample.size() != features.size() - 1) {
     291           44 :             throw std::invalid_argument("Sample size (" + std::to_string(sample.size()) +
     292           66 :                 ") does not match the number of features (" + std::to_string(features.size() - 1) + ")");
     293              :         }
     294       155464 :         std::map<std::string, int> evidence;
     295      2014056 :         for (int i = 0; i < sample.size(); ++i) {
     296      1858592 :             evidence[features[i]] = sample[i];
     297              :         }
     298       310928 :         return exactInference(evidence);
     299       155464 :     }
     300              :     // Return 1xn std::vector of probabilities
     301      1163291 :     std::vector<double> Network::predict_sample(const torch::Tensor& sample)
     302              :     {
     303              :         // Ensure the sample size is equal to the number of features
     304      1163291 :         if (sample.size(0) != features.size() - 1) {
     305           44 :             throw std::invalid_argument("Sample size (" + std::to_string(sample.size(0)) +
     306           66 :                 ") does not match the number of features (" + std::to_string(features.size() - 1) + ")");
     307              :         }
     308      1163269 :         std::map<std::string, int> evidence;
     309     30202277 :         for (int i = 0; i < sample.size(0); ++i) {
     310     29039008 :             evidence[features[i]] = sample[i].item<int>();
     311              :         }
     312      2326538 :         return exactInference(evidence);
     313      1163269 :     }
     314      5150558 :     double Network::computeFactor(std::map<std::string, int>& completeEvidence)
     315              :     {
     316      5150558 :         double result = 1.0;
     317     72453396 :         for (auto& node : getNodes()) {
     318     67302838 :             result *= node.second->getFactorValue(completeEvidence);
     319              :         }
     320      5150558 :         return result;
     321              :     }
     322      1318733 :     std::vector<double> Network::exactInference(std::map<std::string, int>& evidence)
     323              :     {
     324      1318733 :         std::vector<double> result(classNumStates, 0.0);
     325      1318733 :         std::vector<std::thread> threads;
     326      1318733 :         std::mutex mtx;
     327      6469291 :         for (int i = 0; i < classNumStates; ++i) {
     328      5150558 :             threads.emplace_back([this, &result, &evidence, i, &mtx]() {
     329      5150558 :                 auto completeEvidence = std::map<std::string, int>(evidence);
     330      5150558 :                 completeEvidence[getClassName()] = i;
     331      5150558 :                 double factor = computeFactor(completeEvidence);
     332      5150558 :                 std::lock_guard<std::mutex> lock(mtx);
     333      5150558 :                 result[i] = factor;
     334      5150558 :                 });
     335              :         }
     336      6469291 :         for (auto& thread : threads) {
     337      5150558 :             thread.join();
     338              :         }
     339              :         // Normalize result
     340      1318733 :         double sum = accumulate(result.begin(), result.end(), 0.0);
     341      6469291 :         transform(result.begin(), result.end(), result.begin(), [sum](const double& value) { return value / sum; });
     342      2637466 :         return result;
     343      1318733 :     }
     344           77 :     std::vector<std::string> Network::show() const
     345              :     {
     346           77 :         std::vector<std::string> result;
     347              :         // Draw the network
     348          440 :         for (auto& node : nodes) {
     349          363 :             std::string line = node.first + " -> ";
     350          847 :             for (auto child : node.second->getChildren()) {
     351          484 :                 line += child->getName() + ", ";
     352              :             }
     353          363 :             result.push_back(line);
     354          363 :         }
     355           77 :         return result;
     356            0 :     }
     357          242 :     std::vector<std::string> Network::graph(const std::string& title) const
     358              :     {
     359          242 :         auto output = std::vector<std::string>();
     360          242 :         auto prefix = "digraph BayesNet {\nlabel=<BayesNet ";
     361          242 :         auto suffix = ">\nfontsize=30\nfontcolor=blue\nlabelloc=t\nlayout=circo\n";
     362          242 :         std::string header = prefix + title + suffix;
     363          242 :         output.push_back(header);
     364         1925 :         for (auto& node : nodes) {
     365         1683 :             auto result = node.second->graph(className);
     366         1683 :             output.insert(output.end(), result.begin(), result.end());
     367         1683 :         }
     368          242 :         output.push_back("}\n");
     369          484 :         return output;
     370          242 :     }
     371          684 :     std::vector<std::pair<std::string, std::string>> Network::getEdges() const
     372              :     {
     373          684 :         auto edges = std::vector<std::pair<std::string, std::string>>();
     374        10684 :         for (const auto& node : nodes) {
     375        10000 :             auto head = node.first;
     376        27937 :             for (const auto& child : node.second->getChildren()) {
     377        17937 :                 auto tail = child->getName();
     378        17937 :                 edges.push_back({ head, tail });
     379        17937 :             }
     380        10000 :         }
     381          684 :         return edges;
     382            0 :     }
     383          563 :     int Network::getNumEdges() const
     384              :     {
     385          563 :         return getEdges().size();
     386              :     }
     387          605 :     std::vector<std::string> Network::topological_sort()
     388              :     {
     389              :         /* Check if al the fathers of every node are before the node */
     390          605 :         auto result = features;
     391          605 :         result.erase(remove(result.begin(), result.end(), className), result.end());
     392          605 :         bool ending{ false };
     393         1727 :         while (!ending) {
     394         1122 :             ending = true;
     395        10461 :             for (auto feature : features) {
     396         9339 :                 auto fathers = nodes[feature]->getParents();
     397        24750 :                 for (const auto& father : fathers) {
     398        15411 :                     auto fatherName = father->getName();
     399        15411 :                     if (fatherName == className) {
     400         8195 :                         continue;
     401              :                     }
     402              :                     // Check if father is placed before the actual feature
     403         7216 :                     auto it = find(result.begin(), result.end(), fatherName);
     404         7216 :                     if (it != result.end()) {
     405         7216 :                         auto it2 = find(result.begin(), result.end(), feature);
     406         7216 :                         if (it2 != result.end()) {
     407         7216 :                             if (distance(it, it2) < 0) {
     408              :                                 // if it is not, insert it before the feature
     409          671 :                                 result.erase(remove(result.begin(), result.end(), fatherName), result.end());
     410          671 :                                 result.insert(it2, fatherName);
     411          671 :                                 ending = false;
     412              :                             }
     413              :                         } else {
     414            0 :                             throw std::logic_error("Error in topological sort because of node " + feature + " is not in result");
     415              :                         }
     416              :                     } else {
     417            0 :                         throw std::logic_error("Error in topological sort because of node father " + fatherName + " is not in result");
     418              :                     }
     419        15411 :                 }
     420         9339 :             }
     421              :         }
     422          605 :         return result;
     423            0 :     }
     424           22 :     std::string Network::dump_cpt() const
     425              :     {
     426           22 :         std::stringstream oss;
     427          132 :         for (auto& node : nodes) {
     428          110 :             oss << "* " << node.first << ": (" << node.second->getNumStates() << ") : " << node.second->getCPT().sizes() << std::endl;
     429          110 :             oss << node.second->getCPT() << std::endl;
     430              :         }
     431           44 :         return oss.str();
     432           22 :     }
     433              : }
        

Generated by: LCOV version 2.0-1

</html>