Files
BayesNet/html/bayesnet/network/Network.cc.gcov.html

63 KiB

<html lang="en"> <head> </head>
LCOV - code coverage report
Current view: top level - bayesnet/network - Network.cc (source / functions) Coverage Total Hit
Test: coverage.info Lines: 98.3 % 295 290
Test Date: 2024-04-30 20:26:57 Functions: 100.0 % 40 40

            Line data    Source code
       1              : // ***************************************************************
       2              : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
       3              : // SPDX-FileType: SOURCE
       4              : // SPDX-License-Identifier: MIT
       5              : // ***************************************************************
       6              : 
       7              : #include <thread>
       8              : #include <mutex>
       9              : #include <sstream>
      10              : #include "Network.h"
      11              : #include "bayesnet/utils/bayesnetUtils.h"
      12              : namespace bayesnet {
      13          930 :     Network::Network() : fitted{ false }, maxThreads{ 0.95 }, classNumStates{ 0 }, laplaceSmoothing{ 0 }
      14              :     {
      15          930 :     }
      16            4 :     Network::Network(float maxT) : fitted{ false }, maxThreads{ maxT }, classNumStates{ 0 }, laplaceSmoothing{ 0 }
      17              :     {
      18              : 
      19            4 :     }
      20          888 :     Network::Network(const Network& other) : laplaceSmoothing(other.laplaceSmoothing), features(other.features), className(other.className), classNumStates(other.getClassNumStates()),
      21         1776 :         maxThreads(other.getMaxThreads()), fitted(other.fitted), samples(other.samples)
      22              :     {
      23          888 :         if (samples.defined())
      24            2 :             samples = samples.clone();
      25          898 :         for (const auto& node : other.nodes) {
      26           10 :             nodes[node.first] = std::make_unique<Node>(*node.second);
      27              :         }
      28          888 :     }
      29          634 :     void Network::initialize()
      30              :     {
      31          634 :         features.clear();
      32          634 :         className = "";
      33          634 :         classNumStates = 0;
      34          634 :         fitted = false;
      35          634 :         nodes.clear();
      36          634 :         samples = torch::Tensor();
      37          634 :     }
      38          894 :     float Network::getMaxThreads() const
      39              :     {
      40          894 :         return maxThreads;
      41              :     }
      42           24 :     torch::Tensor& Network::getSamples()
      43              :     {
      44           24 :         return samples;
      45              :     }
      46        13374 :     void Network::addNode(const std::string& name)
      47              :     {
      48        13374 :         if (name == "") {
      49            4 :             throw std::invalid_argument("Node name cannot be empty");
      50              :         }
      51        13370 :         if (nodes.find(name) != nodes.end()) {
      52            0 :             return;
      53              :         }
      54        13370 :         if (find(features.begin(), features.end(), name) == features.end()) {
      55        13370 :             features.push_back(name);
      56              :         }
      57        13370 :         nodes[name] = std::make_unique<Node>(name);
      58              :     }
      59          118 :     std::vector<std::string> Network::getFeatures() const
      60              :     {
      61          118 :         return features;
      62              :     }
      63         1070 :     int Network::getClassNumStates() const
      64              :     {
      65         1070 :         return classNumStates;
      66              :     }
      67           24 :     int Network::getStates() const
      68              :     {
      69           24 :         int result = 0;
      70          144 :         for (auto& node : nodes) {
      71          120 :             result += node.second->getNumStates();
      72              :         }
      73           24 :         return result;
      74              :     }
      75      1590160 :     std::string Network::getClassName() const
      76              :     {
      77      1590160 :         return className;
      78              :     }
      79        30532 :     bool Network::isCyclic(const std::string& nodeId, std::unordered_set<std::string>& visited, std::unordered_set<std::string>& recStack)
      80              :     {
      81        30532 :         if (visited.find(nodeId) == visited.end()) // if node hasn't been visited yet
      82              :         {
      83        30532 :             visited.insert(nodeId);
      84        30532 :             recStack.insert(nodeId);
      85        36110 :             for (Node* child : nodes[nodeId]->getChildren()) {
      86         5590 :                 if (visited.find(child->getName()) == visited.end() && isCyclic(child->getName(), visited, recStack))
      87           12 :                     return true;
      88         5582 :                 if (recStack.find(child->getName()) != recStack.end())
      89            4 :                     return true;
      90              :             }
      91              :         }
      92        30520 :         recStack.erase(nodeId); // remove node from recursion stack before function ends
      93        30520 :         return false;
      94              :     }
      95        24954 :     void Network::addEdge(const std::string& parent, const std::string& child)
      96              :     {
      97        24954 :         if (nodes.find(parent) == nodes.end()) {
      98            4 :             throw std::invalid_argument("Parent node " + parent + " does not exist");
      99              :         }
     100        24950 :         if (nodes.find(child) == nodes.end()) {
     101            4 :             throw std::invalid_argument("Child node " + child + " does not exist");
     102              :         }
     103              :         // Temporarily add edge to check for cycles
     104        24946 :         nodes[parent]->addChild(nodes[child].get());
     105        24946 :         nodes[child]->addParent(nodes[parent].get());
     106        24946 :         std::unordered_set<std::string> visited;
     107        24946 :         std::unordered_set<std::string> recStack;
     108        24946 :         if (isCyclic(nodes[child]->getName(), visited, recStack)) // if adding this edge forms a cycle
     109              :         {
     110              :             // remove problematic edge
     111            4 :             nodes[parent]->removeChild(nodes[child].get());
     112            4 :             nodes[child]->removeParent(nodes[parent].get());
     113            4 :             throw std::invalid_argument("Adding this edge forms a cycle in the graph.");
     114              :         }
     115        24950 :     }
     116      1590294 :     std::map<std::string, std::unique_ptr<Node>>& Network::getNodes()
     117              :     {
     118      1590294 :         return nodes;
     119              :     }
     120          712 :     void Network::checkFitData(int n_samples, int n_features, int n_samples_y, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights)
     121              :     {
     122          712 :         if (weights.size(0) != n_samples) {
     123            4 :             throw std::invalid_argument("Weights (" + std::to_string(weights.size(0)) + ") must have the same number of elements as samples (" + std::to_string(n_samples) + ") in Network::fit");
     124              :         }
     125          708 :         if (n_samples != n_samples_y) {
     126            4 :             throw std::invalid_argument("X and y must have the same number of samples in Network::fit (" + std::to_string(n_samples) + " != " + std::to_string(n_samples_y) + ")");
     127              :         }
     128          704 :         if (n_features != featureNames.size()) {
     129            4 :             throw std::invalid_argument("X and features must have the same number of features in Network::fit (" + std::to_string(n_features) + " != " + std::to_string(featureNames.size()) + ")");
     130              :         }
     131          700 :         if (features.size() == 0) {
     132            4 :             throw std::invalid_argument("The network has not been initialized. You must call addNode() before calling fit()");
     133              :         }
     134          696 :         if (n_features != features.size() - 1) {
     135            4 :             throw std::invalid_argument("X and local features must have the same number of features in Network::fit (" + std::to_string(n_features) + " != " + std::to_string(features.size() - 1) + ")");
     136              :         }
     137          692 :         if (find(features.begin(), features.end(), className) == features.end()) {
     138            4 :             throw std::invalid_argument("Class Name not found in Network::features");
     139              :         }
     140        14210 :         for (auto& feature : featureNames) {
     141        13526 :             if (find(features.begin(), features.end(), feature) == features.end()) {
     142            4 :                 throw std::invalid_argument("Feature " + feature + " not found in Network::features");
     143              :             }
     144        13522 :             if (states.find(feature) == states.end()) {
     145            0 :                 throw std::invalid_argument("Feature " + feature + " not found in states");
     146              :             }
     147              :         }
     148          684 :     }
     149          684 :     void Network::setStates(const std::map<std::string, std::vector<int>>& states)
     150              :     {
     151              :         // Set states to every Node in the network
     152          684 :         for_each(features.begin(), features.end(), [this, &states](const std::string& feature) {
     153        14194 :             nodes.at(feature)->setNumStates(states.at(feature).size());
     154        14194 :             });
     155          684 :         classNumStates = nodes.at(className)->getNumStates();
     156          684 :     }
     157              :     // X comes in nxm, where n is the number of features and m the number of samples
     158            2 :     void Network::fit(const torch::Tensor& X, const torch::Tensor& y, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states)
     159              :     {
     160            2 :         checkFitData(X.size(1), X.size(0), y.size(0), featureNames, className, states, weights);
     161            2 :         this->className = className;
     162            2 :         torch::Tensor ytmp = torch::transpose(y.view({ y.size(0), 1 }), 0, 1);
     163            6 :         samples = torch::cat({ X , ytmp }, 0);
     164           10 :         for (int i = 0; i < featureNames.size(); ++i) {
     165           24 :             auto row_feature = X.index({ i, "..." });
     166            8 :         }
     167            2 :         completeFit(states, weights);
     168           12 :     }
     169          668 :     void Network::fit(const torch::Tensor& samples, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states)
     170              :     {
     171          668 :         checkFitData(samples.size(1), samples.size(0) - 1, samples.size(1), featureNames, className, states, weights);
     172          668 :         this->className = className;
     173          668 :         this->samples = samples;
     174          668 :         completeFit(states, weights);
     175          668 :     }
     176              :     // input_data comes in nxm, where n is the number of features and m the number of samples
     177           42 :     void Network::fit(const std::vector<std::vector<int>>& input_data, const std::vector<int>& labels, const std::vector<double>& weights_, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states)
     178              :     {
     179           42 :         const torch::Tensor weights = torch::tensor(weights_, torch::kFloat64);
     180           42 :         checkFitData(input_data[0].size(), input_data.size(), labels.size(), featureNames, className, states, weights);
     181           14 :         this->className = className;
     182              :         // Build tensor of samples (nxm) (n+1 because of the class)
     183           14 :         samples = torch::zeros({ static_cast<int>(input_data.size() + 1), static_cast<int>(input_data[0].size()) }, torch::kInt32);
     184           70 :         for (int i = 0; i < featureNames.size(); ++i) {
     185          224 :             samples.index_put_({ i, "..." }, torch::tensor(input_data[i], torch::kInt32));
     186              :         }
     187           56 :         samples.index_put_({ -1, "..." }, torch::tensor(labels, torch::kInt32));
     188           14 :         completeFit(states, weights);
     189          112 :     }
     190          684 :     void Network::completeFit(const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights)
     191              :     {
     192          684 :         setStates(states);
     193          684 :         laplaceSmoothing = 1.0 / samples.size(1); // To use in CPT computation
     194          684 :         std::vector<std::thread> threads;
     195        14878 :         for (auto& node : nodes) {
     196        14194 :             threads.emplace_back([this, &node, &weights]() {
     197        14194 :                 node.second->computeCPT(samples, features, laplaceSmoothing, weights);
     198        14194 :                 });
     199              :         }
     200        14878 :         for (auto& thread : threads) {
     201        14194 :             thread.join();
     202              :         }
     203          684 :         fitted = true;
     204          684 :     }
     205         1588 :     torch::Tensor Network::predict_tensor(const torch::Tensor& samples, const bool proba)
     206              :     {
     207         1588 :         if (!fitted) {
     208            4 :             throw std::logic_error("You must call fit() before calling predict()");
     209              :         }
     210         1584 :         torch::Tensor result;
     211         1584 :         result = torch::zeros({ samples.size(1), classNumStates }, torch::kFloat64);
     212       377028 :         for (int i = 0; i < samples.size(1); ++i) {
     213      1126344 :             const torch::Tensor sample = samples.index({ "...", i });
     214       375448 :             auto psample = predict_sample(sample);
     215       375444 :             auto temp = torch::tensor(psample, torch::kFloat64);
     216              :             //            result.index_put_({ i, "..." }, torch::tensor(predict_sample(sample), torch::kFloat64));
     217      1126332 :             result.index_put_({ i, "..." }, temp);
     218       375448 :         }
     219         1580 :         if (proba)
     220          738 :             return result;
     221         1684 :         return result.argmax(1);
     222       752476 :     }
     223              :     // Return mxn tensor of probabilities
     224          738 :     torch::Tensor Network::predict_proba(const torch::Tensor& samples)
     225              :     {
     226          738 :         return predict_tensor(samples, true);
     227              :     }
     228              : 
     229              :     // Return mxn tensor of probabilities
     230          850 :     torch::Tensor Network::predict(const torch::Tensor& samples)
     231              :     {
     232          850 :         return predict_tensor(samples, false);
     233              :     }
     234              : 
     235              :     // Return mx1 std::vector of predictions
     236              :     // tsamples is nxm std::vector of samples
     237           24 :     std::vector<int> Network::predict(const std::vector<std::vector<int>>& tsamples)
     238              :     {
     239           24 :         if (!fitted) {
     240            8 :             throw std::logic_error("You must call fit() before calling predict()");
     241              :         }
     242           16 :         std::vector<int> predictions;
     243           16 :         std::vector<int> sample;
     244         1782 :         for (int row = 0; row < tsamples[0].size(); ++row) {
     245         1770 :             sample.clear();
     246        13126 :             for (int col = 0; col < tsamples.size(); ++col) {
     247        11356 :                 sample.push_back(tsamples[col][row]);
     248              :             }
     249         1770 :             std::vector<double> classProbabilities = predict_sample(sample);
     250              :             // Find the class with the maximum posterior probability
     251         1766 :             auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end());
     252         1766 :             int predictedClass = distance(classProbabilities.begin(), maxElem);
     253         1766 :             predictions.push_back(predictedClass);
     254         1766 :         }
     255           24 :         return predictions;
     256           20 :     }
     257              :     // Return mxn std::vector of probabilities
     258              :     // tsamples is nxm std::vector of samples
     259          132 :     std::vector<std::vector<double>> Network::predict_proba(const std::vector<std::vector<int>>& tsamples)
     260              :     {
     261          132 :         if (!fitted) {
     262            4 :             throw std::logic_error("You must call fit() before calling predict_proba()");
     263              :         }
     264          128 :         std::vector<std::vector<double>> predictions;
     265          128 :         std::vector<int> sample;
     266        24798 :         for (int row = 0; row < tsamples[0].size(); ++row) {
     267        24670 :             sample.clear();
     268       219650 :             for (int col = 0; col < tsamples.size(); ++col) {
     269       194980 :                 sample.push_back(tsamples[col][row]);
     270              :             }
     271        24670 :             predictions.push_back(predict_sample(sample));
     272              :         }
     273          256 :         return predictions;
     274          128 :     }
     275           10 :     double Network::score(const std::vector<std::vector<int>>& tsamples, const std::vector<int>& labels)
     276              :     {
     277           10 :         std::vector<int> y_pred = predict(tsamples);
     278            6 :         int correct = 0;
     279         1162 :         for (int i = 0; i < y_pred.size(); ++i) {
     280         1156 :             if (y_pred[i] == labels[i]) {
     281          972 :                 correct++;
     282              :             }
     283              :         }
     284           12 :         return (double)correct / y_pred.size();
     285            6 :     }
     286              :     // Return 1xn std::vector of probabilities
     287        26440 :     std::vector<double> Network::predict_sample(const std::vector<int>& sample)
     288              :     {
     289              :         // Ensure the sample size is equal to the number of features
     290        26440 :         if (sample.size() != features.size() - 1) {
     291            8 :             throw std::invalid_argument("Sample size (" + std::to_string(sample.size()) +
     292           12 :                 ") does not match the number of features (" + std::to_string(features.size() - 1) + ")");
     293              :         }
     294        26436 :         std::map<std::string, int> evidence;
     295       232760 :         for (int i = 0; i < sample.size(); ++i) {
     296       206324 :             evidence[features[i]] = sample[i];
     297              :         }
     298        52872 :         return exactInference(evidence);
     299        26436 :     }
     300              :     // Return 1xn std::vector of probabilities
     301       375448 :     std::vector<double> Network::predict_sample(const torch::Tensor& sample)
     302              :     {
     303              :         // Ensure the sample size is equal to the number of features
     304       375448 :         if (sample.size(0) != features.size() - 1) {
     305            8 :             throw std::invalid_argument("Sample size (" + std::to_string(sample.size(0)) +
     306           12 :                 ") does not match the number of features (" + std::to_string(features.size() - 1) + ")");
     307              :         }
     308       375444 :         std::map<std::string, int> evidence;
     309      8888488 :         for (int i = 0; i < sample.size(0); ++i) {
     310      8513044 :             evidence[features[i]] = sample[i].item<int>();
     311              :         }
     312       750888 :         return exactInference(evidence);
     313       375444 :     }
     314      1590148 :     double Network::computeFactor(std::map<std::string, int>& completeEvidence)
     315              :     {
     316      1590148 :         double result = 1.0;
     317     33392584 :         for (auto& node : getNodes()) {
     318     31802436 :             result *= node.second->getFactorValue(completeEvidence);
     319              :         }
     320      1590148 :         return result;
     321              :     }
     322       401880 :     std::vector<double> Network::exactInference(std::map<std::string, int>& evidence)
     323              :     {
     324       401880 :         std::vector<double> result(classNumStates, 0.0);
     325       401880 :         std::vector<std::thread> threads;
     326       401880 :         std::mutex mtx;
     327      1992028 :         for (int i = 0; i < classNumStates; ++i) {
     328      1590148 :             threads.emplace_back([this, &result, &evidence, i, &mtx]() {
     329      1590148 :                 auto completeEvidence = std::map<std::string, int>(evidence);
     330      1590148 :                 completeEvidence[getClassName()] = i;
     331      1590148 :                 double factor = computeFactor(completeEvidence);
     332      1590148 :                 std::lock_guard<std::mutex> lock(mtx);
     333      1590148 :                 result[i] = factor;
     334      1590148 :                 });
     335              :         }
     336      1992028 :         for (auto& thread : threads) {
     337      1590148 :             thread.join();
     338              :         }
     339              :         // Normalize result
     340       401880 :         double sum = accumulate(result.begin(), result.end(), 0.0);
     341      1992028 :         transform(result.begin(), result.end(), result.begin(), [sum](const double& value) { return value / sum; });
     342       803760 :         return result;
     343       401880 :     }
     344           14 :     std::vector<std::string> Network::show() const
     345              :     {
     346           14 :         std::vector<std::string> result;
     347              :         // Draw the network
     348           80 :         for (auto& node : nodes) {
     349           66 :             std::string line = node.first + " -> ";
     350          154 :             for (auto child : node.second->getChildren()) {
     351           88 :                 line += child->getName() + ", ";
     352              :             }
     353           66 :             result.push_back(line);
     354           66 :         }
     355           14 :         return result;
     356            0 :     }
     357           44 :     std::vector<std::string> Network::graph(const std::string& title) const
     358              :     {
     359           44 :         auto output = std::vector<std::string>();
     360           44 :         auto prefix = "digraph BayesNet {\nlabel=<BayesNet ";
     361           44 :         auto suffix = ">\nfontsize=30\nfontcolor=blue\nlabelloc=t\nlayout=circo\n";
     362           44 :         std::string header = prefix + title + suffix;
     363           44 :         output.push_back(header);
     364          350 :         for (auto& node : nodes) {
     365          306 :             auto result = node.second->graph(className);
     366          306 :             output.insert(output.end(), result.begin(), result.end());
     367          306 :         }
     368           44 :         output.push_back("}\n");
     369           88 :         return output;
     370           44 :     }
     371          132 :     std::vector<std::pair<std::string, std::string>> Network::getEdges() const
     372              :     {
     373          132 :         auto edges = std::vector<std::pair<std::string, std::string>>();
     374         2906 :         for (const auto& node : nodes) {
     375         2774 :             auto head = node.first;
     376         7924 :             for (const auto& child : node.second->getChildren()) {
     377         5150 :                 auto tail = child->getName();
     378         5150 :                 edges.push_back({ head, tail });
     379         5150 :             }
     380         2774 :         }
     381          132 :         return edges;
     382            0 :     }
     383          110 :     int Network::getNumEdges() const
     384              :     {
     385          110 :         return getEdges().size();
     386              :     }
     387          110 :     std::vector<std::string> Network::topological_sort()
     388              :     {
     389              :         /* Check if al the fathers of every node are before the node */
     390          110 :         auto result = features;
     391          110 :         result.erase(remove(result.begin(), result.end(), className), result.end());
     392          110 :         bool ending{ false };
     393          314 :         while (!ending) {
     394          204 :             ending = true;
     395         1902 :             for (auto feature : features) {
     396         1698 :                 auto fathers = nodes[feature]->getParents();
     397         4500 :                 for (const auto& father : fathers) {
     398         2802 :                     auto fatherName = father->getName();
     399         2802 :                     if (fatherName == className) {
     400         1490 :                         continue;
     401              :                     }
     402              :                     // Check if father is placed before the actual feature
     403         1312 :                     auto it = find(result.begin(), result.end(), fatherName);
     404         1312 :                     if (it != result.end()) {
     405         1312 :                         auto it2 = find(result.begin(), result.end(), feature);
     406         1312 :                         if (it2 != result.end()) {
     407         1312 :                             if (distance(it, it2) < 0) {
     408              :                                 // if it is not, insert it before the feature
     409          122 :                                 result.erase(remove(result.begin(), result.end(), fatherName), result.end());
     410          122 :                                 result.insert(it2, fatherName);
     411          122 :                                 ending = false;
     412              :                             }
     413              :                         }
     414              :                     }
     415         2802 :                 }
     416         1698 :             }
     417              :         }
     418          110 :         return result;
     419            0 :     }
     420            4 :     std::string Network::dump_cpt() const
     421              :     {
     422            4 :         std::stringstream oss;
     423           24 :         for (auto& node : nodes) {
     424           20 :             oss << "* " << node.first << ": (" << node.second->getNumStates() << ") : " << node.second->getCPT().sizes() << std::endl;
     425           20 :             oss << node.second->getCPT() << std::endl;
     426              :         }
     427            8 :         return oss.str();
     428            4 :     }
     429              : }
        

Generated by: LCOV version 2.0-1

</html>