LCOV - code coverage report
Current view: top level - bayesnet/network - Network.cc (source / functions) Coverage Total Hit
Test: coverage.info Lines: 97.6 % 297 290
Test Date: 2024-04-21 17:30:26 Functions: 100.0 % 40 40

            Line data    Source code
       1              : // ***************************************************************
       2              : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
       3              : // SPDX-FileType: SOURCE
       4              : // SPDX-License-Identifier: MIT
       5              : // ***************************************************************
       6              : 
       7              : #include <thread>
       8              : #include <mutex>
       9              : #include <sstream>
      10              : #include "Network.h"
      11              : #include "bayesnet/utils/bayesnetUtils.h"
      12              : namespace bayesnet {
      13          435 :     Network::Network() : fitted{ false }, maxThreads{ 0.95 }, classNumStates{ 0 }, laplaceSmoothing{ 0 } // Default constructor: unfitted network with maxThreads = 0.95, classNumStates = 0 and laplaceSmoothing = 0.
      14              :     {
      15          435 :     }
      16            2 :     Network::Network(float maxT) : fitted{ false }, maxThreads{ maxT }, classNumStates{ 0 }, laplaceSmoothing{ 0 } // Same initial state as the default constructor, except that maxThreads is taken from maxT.
      17              :     {
      18              : 
      19            2 :     }
      20          414 :     Network::Network(const Network& other) : laplaceSmoothing(other.laplaceSmoothing), features(other.features), className(other.className), classNumStates(other.getClassNumStates()), // Copy constructor: copies the configuration members of other, then rebuilds samples and nodes below.
      21          828 :         maxThreads(other.getMaxThreads()), fitted(other.fitted), samples(other.samples)
      22              :     {
      23          414 :         if (samples.defined())
      24            1 :             samples = samples.clone(); // only a defined tensor is cloned; presumably so the copy does not share data with other (confirm against torch semantics)
      25          419 :         for (const auto& node : other.nodes) {
      26            5 :             nodes[node.first] = std::make_unique<Node>(*node.second); // each Node is copy-constructed into a new unique_ptr; NOTE(review): where the copied parent/child pointers point depends on Node's copy constructor, not visible here
      27              :         }
      28          414 :     }
      29          286 :     void Network::initialize() // Resets the network to an empty, unfitted state; maxThreads and laplaceSmoothing are left untouched.
      30              :     {
      31          286 :         features.clear();
      32          286 :         className = "";
      33          286 :         classNumStates = 0;
      34          286 :         fitted = false;
      35          286 :         nodes.clear();
      36          286 :         samples = torch::Tensor(); // a default-constructed tensor replaces any stored samples
      37          286 :     }
      38          417 :     float Network::getMaxThreads() const // Returns the stored maxThreads value.
      39              :     {
      40          417 :         return maxThreads;
      41              :     }
      42           12 :     torch::Tensor& Network::getSamples() // Returns a non-const reference to the stored samples tensor.
      43              :     {
      44           12 :         return samples;
      45              :     }
      46         8878 :     void Network::addNode(const std::string& name) // Adds a node called name. Throws std::invalid_argument if name is empty; does nothing if the node already exists.
      47              :     {
      48         8878 :         if (name == "") {
      49            2 :             throw std::invalid_argument("Node name cannot be empty");
      50              :         }
      51         8876 :         if (nodes.find(name) != nodes.end()) {
      52            0 :             return; // node already registered: keep the existing one
      53              :         }
      54         8876 :         if (find(features.begin(), features.end(), name) == features.end()) { // features keeps each name only once
      55         8876 :             features.push_back(name);
      56              :         }
      57         8876 :         nodes[name] = std::make_unique<Node>(name);
      58              :     }
      59           52 :     std::vector<std::string> Network::getFeatures() const // Returns a copy of the features vector (names are appended by addNode in insertion order).
      60              :     {
      61           52 :         return features;
      62              :     }
      63          496 :     int Network::getClassNumStates() const // Returns classNumStates (0 until setStates() assigns it).
      64              :     {
      65          496 :         return classNumStates;
      66              :     }
      67           12 :     int Network::getStates() const // Returns the sum of getNumStates() over every node in the network.
      68              :     {
      69           12 :         int result = 0;
      70           72 :         for (auto& node : nodes) {
      71           60 :             result += node.second->getNumStates();
      72              :         }
      73           12 :         return result;
      74              :     }
      75       437774 :     std::string Network::getClassName() const // Returns a copy of className.
      76              :     {
      77       437774 :         return className;
      78              :     }
      79        22324 :     bool Network::isCyclic(const std::string& nodeId, std::unordered_set<std::string>& visited, std::unordered_set<std::string>& recStack) // Depth-first walk along child links starting at nodeId; returns true when a child that is still in recStack is reached.
      80              :     {
      81        22324 :         if (visited.find(nodeId) == visited.end()) // only expand a node the first time it is seen
      82              :         {
      83        22324 :             visited.insert(nodeId);
      84        22324 :             recStack.insert(nodeId);
      85        27702 :             for (Node* child : nodes[nodeId]->getChildren()) { // NOTE(review): operator[] is used, so nodeId is expected to be an existing key
      86         5384 :                 if (visited.find(child->getName()) == visited.end() && isCyclic(child->getName(), visited, recStack))
      87            6 :                     return true; // a cycle was found deeper down; recStack is not unwound on this path
      88         5380 :                 if (recStack.find(child->getName()) != recStack.end())
      89            2 :                     return true; // child is on the current recursion path: back edge
      90              :             }
      91              :         }
      92        22318 :         recStack.erase(nodeId); // remove node from recursion stack before function ends
      93        22318 :         return false;
      94              :     }
      95        16946 :     void Network::addEdge(const std::string& parent, const std::string& child) // Adds the edge parent -> child. Throws std::invalid_argument if either node is missing or if the edge would create a cycle.
      96              :     {
      97        16946 :         if (nodes.find(parent) == nodes.end()) {
      98            2 :             throw std::invalid_argument("Parent node " + parent + " does not exist");
      99              :         }
     100        16944 :         if (nodes.find(child) == nodes.end()) {
     101            2 :             throw std::invalid_argument("Child node " + child + " does not exist");
     102              :         }
     103              :         // Temporarily add edge to check for cycles
     104        16942 :         nodes[parent]->addChild(nodes[child].get());
     105        16942 :         nodes[child]->addParent(nodes[parent].get());
     106        16942 :         std::unordered_set<std::string> visited;
     107        16942 :         std::unordered_set<std::string> recStack;
     108        16942 :         if (isCyclic(nodes[child]->getName(), visited, recStack)) // the search starts at child; true means the new edge forms a cycle
     109              :         {
     110              :             // remove problematic edge so the network is left as it was before the call
     111            2 :             nodes[parent]->removeChild(nodes[child].get());
     112            2 :             nodes[child]->removeParent(nodes[parent].get());
     113            2 :             throw std::invalid_argument("Adding this edge forms a cycle in the graph.");
     114              :         }
     115        16944 :     }
     116       437841 :     std::map<std::string, std::unique_ptr<Node>>& Network::getNodes() // Returns a non-const reference to the name -> Node map owned by the network.
     117              :     {
     118       437841 :         return nodes;
     119              :     }
     120          327 :     void Network::checkFitData(int n_samples, int n_features, int n_samples_y, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights) // Validates the fit() arguments against each other and against the network; throws std::invalid_argument on the first failed check. The className parameter shadows the member of the same name.
     121              :     {
     122          327 :         if (weights.size(0) != n_samples) {
     123            2 :             throw std::invalid_argument("Weights (" + std::to_string(weights.size(0)) + ") must have the same number of elements as samples (" + std::to_string(n_samples) + ") in Network::fit");
     124              :         }
     125          325 :         if (n_samples != n_samples_y) {
     126            2 :             throw std::invalid_argument("X and y must have the same number of samples in Network::fit (" + std::to_string(n_samples) + " != " + std::to_string(n_samples_y) + ")");
     127              :         }
     128          323 :         if (n_features != featureNames.size()) {
     129            2 :             throw std::invalid_argument("X and features must have the same number of features in Network::fit (" + std::to_string(n_features) + " != " + std::to_string(featureNames.size()) + ")");
     130              :         }
     131          321 :         if (features.size() == 0) { // checked before the "size() - 1" below, which would wrap around on an empty vector
     132            2 :             throw std::invalid_argument("The network has not been initialized. You must call addNode() before calling fit()");
     133              :         }
     134          319 :         if (n_features != features.size() - 1) { // features must also contain className (checked next), hence the - 1
     135            2 :             throw std::invalid_argument("X and local features must have the same number of features in Network::fit (" + std::to_string(n_features) + " != " + std::to_string(features.size() - 1) + ")");
     136              :         }
     137          317 :         if (find(features.begin(), features.end(), className) == features.end()) {
     138            2 :             throw std::invalid_argument("Class Name not found in Network::features");
     139              :         }
     140         9296 :         for (auto& feature : featureNames) { // only featureNames entries are checked; className itself is not looked up in states here
     141         8983 :             if (find(features.begin(), features.end(), feature) == features.end()) {
     142            2 :                 throw std::invalid_argument("Feature " + feature + " not found in Network::features");
     143              :             }
     144         8981 :             if (states.find(feature) == states.end()) {
     145            0 :                 throw std::invalid_argument("Feature " + feature + " not found in states");
     146              :             }
     147              :         }
     148          313 :     }
     149          313 :     void Network::setStates(const std::map<std::string, std::vector<int>>& states) // Sets every node's number of states from states and stores the class node's count in classNumStates.
     150              :     {
     151              :         // Set states to every Node in the network; map::at throws std::out_of_range when a feature is missing from nodes or states
     152          313 :         for_each(features.begin(), features.end(), [this, &states](const std::string& feature) {
     153         9288 :             nodes.at(feature)->setNumStates(states.at(feature).size()); // number of states = length of the feature's entry in states
     154         9288 :             });
     155          313 :         classNumStates = nodes.at(className)->getNumStates();
     156          313 :     }
     157              :     // X comes in n x m, where n is the number of features and m the number of samples; y holds the m class labels
     158            1 :     void Network::fit(const torch::Tensor& X, const torch::Tensor& y, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states) // Validates the input, stores X with y appended as the last row in samples, then calls completeFit().
     159              :     {
     160            1 :         checkFitData(X.size(1), X.size(0), y.size(0), featureNames, className, states, weights);
     161            1 :         this->className = className;
     162            1 :         torch::Tensor ytmp = torch::transpose(y.view({ y.size(0), 1 }), 0, 1); // view y as m x 1, then transpose to 1 x m so it can be concatenated as a row
     163            3 :         samples = torch::cat({ X , ytmp }, 0);
     164            5 :         for (int i = 0; i < featureNames.size(); ++i) {
     165           12 :             auto row_feature = X.index({ i, "..." }); // NOTE(review): row_feature is never used; this loop has no visible effect on the fit
     166            4 :         }
     167            1 :         completeFit(states, weights);
     168            6 :     }
     169          305 :     void Network::fit(const torch::Tensor& samples, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states) // Fit from one tensor that already holds the feature rows plus one extra row (presumably the class labels, as in the other fit overloads).
     170              :     {
     171          305 :         checkFitData(samples.size(1), samples.size(0) - 1, samples.size(1), featureNames, className, states, weights); // samples.size(1) is passed for both sample counts, so the X/y size check cannot fail here
     172          305 :         this->className = className;
     173          305 :         this->samples = samples; // the argument is assigned as is; no clone() call here
     174          305 :         completeFit(states, weights);
     175          305 :     }
     176              :     // input_data comes in n x m, where n is the number of features and m the number of samples; labels holds the m class values
     177           21 :     void Network::fit(const std::vector<std::vector<int>>& input_data, const std::vector<int>& labels, const std::vector<double>& weights_, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states) // Converts the std::vector inputs to tensors, validates them, builds samples and calls completeFit().
     178              :     {
     179           21 :         const torch::Tensor weights = torch::tensor(weights_, torch::kFloat64);
     180           21 :         checkFitData(input_data[0].size(), input_data.size(), labels.size(), featureNames, className, states, weights); // NOTE(review): input_data[0] is read before any emptiness check
     181            7 :         this->className = className;
     182              :         // Build tensor of samples ((n+1) x m, the extra row is for the class), filled with zeros of type kInt32
     183            7 :         samples = torch::zeros({ static_cast<int>(input_data.size() + 1), static_cast<int>(input_data[0].size()) }, torch::kInt32);
     184           35 :         for (int i = 0; i < featureNames.size(); ++i) { // checkFitData has verified that featureNames.size() equals input_data.size()
     185          112 :             samples.index_put_({ i, "..." }, torch::tensor(input_data[i], torch::kInt32));
     186              :         }
     187           28 :         samples.index_put_({ -1, "..." }, torch::tensor(labels, torch::kInt32)); // the last row holds the labels
     188            7 :         completeFit(states, weights);
     189           56 :     }
     190          313 :     void Network::completeFit(const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights) // Common tail of every fit(): sets the node states, computes each node's CPT and marks the network as fitted.
     191              :     {
     192          313 :         setStates(states);
     193          313 :         laplaceSmoothing = 1.0 / samples.size(1); // To use in CPT computation: 1 / number of sample columns
     194          313 :         std::vector<std::thread> threads;
     195         9601 :         for (auto& node : nodes) {
     196         9288 :             threads.emplace_back([this, &node, &weights]() { // one std::thread per node; maxThreads is not consulted here
     197         9288 :                 node.second->computeCPT(samples, features, laplaceSmoothing, weights);
     198         9288 :                 });
     199              :         }
     200         9601 :         for (auto& thread : threads) {
     201         9288 :             thread.join(); // wait for every CPT before flagging the network as fitted
     202              :         }
     203          313 :         fitted = true;
     204          313 :     }
     205          549 :     torch::Tensor Network::predict_tensor(const torch::Tensor& samples, const bool proba) // Returns the per-sample class probabilities when proba is true, otherwise the argmax over classes per sample. Throws std::logic_error if not fitted. The samples parameter shadows the member.
     206              :     {
     207          549 :         if (!fitted) {
     208            2 :             throw std::logic_error("You must call fit() before calling predict()");
     209              :         }
     210          547 :         torch::Tensor result;
     211          547 :         result = torch::zeros({ samples.size(1), classNumStates }, torch::kFloat64); // one row per sample, one column per class state
     212        96385 :         for (int i = 0; i < samples.size(1); ++i) {
     213       287520 :             const torch::Tensor sample = samples.index({ "...", i }); // column i of samples is one sample
     214        95840 :             auto psample = predict_sample(sample);
     215        95838 :             auto temp = torch::tensor(psample, torch::kFloat64);
     216              :             //            result.index_put_({ i, "..." }, torch::tensor(predict_sample(sample), torch::kFloat64));
     217       287514 :             result.index_put_({ i, "..." }, temp);
     218        95840 :         }
     219          545 :         if (proba)
     220          304 :             return result;
     221          482 :         return result.argmax(1);
     222       192225 :     }
     223              :     // Return an m x classNumStates tensor of probabilities, where m = samples.size(1) (predict_tensor with proba = true)
     224          304 :     torch::Tensor Network::predict_proba(const torch::Tensor& samples)
     225              :     {
     226          304 :         return predict_tensor(samples, true);
     227              :     }
     228              : 
     229              :     // Return a tensor with the argmax class index of each sample (predict_tensor with proba = false)
     230          245 :     torch::Tensor Network::predict(const torch::Tensor& samples)
     231              :     {
     232          245 :         return predict_tensor(samples, false);
     233              :     }
     234              : 
     235              :     // Return a std::vector with one predicted class index per sample (m entries)
     236              :     // tsamples is n x m: tsamples[col][row] is the value of feature col for sample row
     237           12 :     std::vector<int> Network::predict(const std::vector<std::vector<int>>& tsamples) // Throws std::logic_error if the network has not been fitted.
     238              :     {
     239           12 :         if (!fitted) {
     240            4 :             throw std::logic_error("You must call fit() before calling predict()");
     241              :         }
     242            8 :         std::vector<int> predictions;
     243            8 :         std::vector<int> sample;
     244          891 :         for (int row = 0; row < tsamples[0].size(); ++row) { // NOTE(review): tsamples[0] is read without an emptiness check
     245          885 :             sample.clear();
     246         6563 :             for (int col = 0; col < tsamples.size(); ++col) {
     247         5678 :                 sample.push_back(tsamples[col][row]); // gather one value per feature to build the sample
     248              :             }
     249          885 :             std::vector<double> classProbabilities = predict_sample(sample);
     250              :             // Find the class with the maximum posterior probability (max_element returns the first maximum on ties)
     251          883 :             auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end());
     252          883 :             int predictedClass = distance(classProbabilities.begin(), maxElem);
     253          883 :             predictions.push_back(predictedClass);
     254          883 :         }
     255           12 :         return predictions;
     256           10 :     }
     257              :     // Return one std::vector of class probabilities per sample (m entries)
     258              :     // tsamples is n x m: tsamples[col][row] is the value of feature col for sample row
     259           68 :     std::vector<std::vector<double>> Network::predict_proba(const std::vector<std::vector<int>>& tsamples) // Throws std::logic_error if the network has not been fitted.
     260              :     {
     261           68 :         if (!fitted) {
     262            2 :             throw std::logic_error("You must call fit() before calling predict_proba()");
     263              :         }
     264           66 :         std::vector<std::vector<double>> predictions;
     265           66 :         std::vector<int> sample;
     266        12787 :         for (int row = 0; row < tsamples[0].size(); ++row) { // NOTE(review): tsamples[0] is read without an emptiness check
     267        12721 :             sample.clear();
     268       193587 :             for (int col = 0; col < tsamples.size(); ++col) {
     269       180866 :                 sample.push_back(tsamples[col][row]); // gather one value per feature to build the sample
     270              :             }
     271        12721 :             predictions.push_back(predict_sample(sample));
     272              :         }
     273          132 :         return predictions;
     274           66 :     }
     275            5 :     double Network::score(const std::vector<std::vector<int>>& tsamples, const std::vector<int>& labels) // Returns the fraction of predictions for tsamples that equal the matching entry of labels.
     276              :     {
     277            5 :         std::vector<int> y_pred = predict(tsamples);
     278            3 :         int correct = 0;
     279          581 :         for (int i = 0; i < y_pred.size(); ++i) {
     280          578 :             if (y_pred[i] == labels[i]) { // NOTE(review): labels is indexed up to y_pred.size() without a size check
     281          486 :                 correct++;
     282              :             }
     283              :         }
     284            6 :         return (double)correct / y_pred.size(); // the cast forces a floating-point division
     285            3 :     }
     286              :     // Return the normalized class probabilities for a single sample (one entry per class state)
     287        13606 :     std::vector<double> Network::predict_sample(const std::vector<int>& sample) // Throws std::invalid_argument when the sample length is not features.size() - 1.
     288              :     {
     289              :         // Ensure the sample size is equal to the number of features (features.size() minus one)
     290        13606 :         if (sample.size() != features.size() - 1) {
     291            4 :             throw std::invalid_argument("Sample size (" + std::to_string(sample.size()) +
     292            6 :                 ") does not match the number of features (" + std::to_string(features.size() - 1) + ")");
     293              :         }
     294        13604 :         std::map<std::string, int> evidence;
     295       200142 :         for (int i = 0; i < sample.size(); ++i) {
     296       186538 :             evidence[features[i]] = sample[i]; // value i is paired with the i-th name in features; assumes the class name is the last entry of features (TODO confirm)
     297              :         }
     298        27208 :         return exactInference(evidence);
     299        13604 :     }
     300              :     // Return the normalized class probabilities for a single sample (one entry per class state)
     301        95840 :     std::vector<double> Network::predict_sample(const torch::Tensor& sample) // Throws std::invalid_argument when sample.size(0) is not features.size() - 1.
     302              :     {
     303              :         // Ensure the sample size is equal to the number of features (features.size() minus one)
     304        95840 :         if (sample.size(0) != features.size() - 1) {
     305            4 :             throw std::invalid_argument("Sample size (" + std::to_string(sample.size(0)) +
     306            6 :                 ") does not match the number of features (" + std::to_string(features.size() - 1) + ")");
     307              :         }
     308        95838 :         std::map<std::string, int> evidence;
     309      2448008 :         for (int i = 0; i < sample.size(0); ++i) {
     310      2352170 :             evidence[features[i]] = sample[i].item<int>(); // element i is read as an int and paired with the i-th name in features; assumes the class name is the last entry of features (TODO confirm)
     311              :         }
     312       191676 :         return exactInference(evidence);
     313        95838 :     }
     314       437768 :     double Network::computeFactor(std::map<std::string, int>& completeEvidence) // Returns the product of getFactorValue(completeEvidence) over every node in the network.
     315              :     {
     316       437768 :         double result = 1.0;
     317      6084992 :         for (auto& node : getNodes()) {
     318      5647224 :             result *= node.second->getFactorValue(completeEvidence);
     319              :         }
     320       437768 :         return result;
     321              :     }
     322       109442 :     std::vector<double> Network::exactInference(std::map<std::string, int>& evidence) // Returns one value per class state: the factor computed with the class fixed to that state, divided by the sum over all states.
     323              :     {
     324       109442 :         std::vector<double> result(classNumStates, 0.0);
     325       109442 :         std::vector<std::thread> threads;
     326       109442 :         std::mutex mtx;
     327       547210 :         for (int i = 0; i < classNumStates; ++i) {
     328       437768 :             threads.emplace_back([this, &result, &evidence, i, &mtx]() { // one std::thread per class state
     329       437768 :                 auto completeEvidence = std::map<std::string, int>(evidence); // private copy, so the shared evidence map is only read
     330       437768 :                 completeEvidence[getClassName()] = i;
     331       437768 :                 double factor = computeFactor(completeEvidence);
     332       437768 :                 std::lock_guard<std::mutex> lock(mtx); // the factor is computed outside the lock; only the store is guarded
     333       437768 :                 result[i] = factor;
     334       437768 :                 });
     335              :         }
     336       547210 :         for (auto& thread : threads) {
     337       437768 :             thread.join();
     338              :         }
     339              :         // Normalize result so that the entries add up to 1
     340       109442 :         double sum = accumulate(result.begin(), result.end(), 0.0);
     341       547210 :         transform(result.begin(), result.end(), result.begin(), [sum](const double& value) { return value / sum; }); // NOTE(review): there is no guard for sum == 0
     342       218884 :         return result;
     343       109442 :     }
     344            7 :     std::vector<std::string> Network::show() const // Returns one text line per node, in the form "name -> child1, child2, ".
     345              :     {
     346            7 :         std::vector<std::string> result;
     347              :         // Draw the network
     348           40 :         for (auto& node : nodes) {
     349           33 :             std::string line = node.first + " -> ";
     350           77 :             for (auto child : node.second->getChildren()) {
     351           44 :                 line += child->getName() + ", "; // the separator is appended after every child, including the last one
     352              :             }
     353           33 :             result.push_back(line);
     354           33 :         }
     355            7 :         return result;
     356            0 :     }
     357           22 :     std::vector<std::string> Network::graph(const std::string& title) const // Returns the lines of a "digraph BayesNet { ... }" description: a header containing title, the lines from each node's graph(className), and the closing brace.
     358              :     {
     359           22 :         auto output = std::vector<std::string>();
     360           22 :         auto prefix = "digraph BayesNet {\nlabel=<BayesNet ";
     361           22 :         auto suffix = ">\nfontsize=30\nfontcolor=blue\nlabelloc=t\nlayout=circo\n";
     362           22 :         std::string header = prefix + title + suffix; // title is inserted between the two literal fragments
     363           22 :         output.push_back(header);
     364          175 :         for (auto& node : nodes) {
     365          153 :             auto result = node.second->graph(className);
     366          153 :             output.insert(output.end(), result.begin(), result.end());
     367          153 :         }
     368           22 :         output.push_back("}\n");
     369           44 :         return output;
     370           22 :     }
     371           59 :     std::vector<std::pair<std::string, std::string>> Network::getEdges() const // Returns every edge as a (node name, child name) pair, visiting nodes in map order.
     372              :     {
     373           59 :         auto edges = std::vector<std::pair<std::string, std::string>>();
     374          937 :         for (const auto& node : nodes) {
     375          878 :             auto head = node.first; // name of the node that owns the children list
     376         2456 :             for (const auto& child : node.second->getChildren()) {
     377         1578 :                 auto tail = child->getName();
     378         1578 :                 edges.push_back({ head, tail });
     379         1578 :             }
     380          878 :         }
     381           59 :         return edges;
     382            0 :     }
     383           48 :     int Network::getNumEdges() const // Returns the number of edges; builds the full edge list with getEdges() to count them.
     384              :     {
     385           48 :         return getEdges().size();
     386              :     }
     387           55 :     std::vector<std::string> Network::topological_sort() // Returns the feature names without className, reordered so that every parent comes before its children.
     388              :     {
     389              :         /* Check that all the parents of every node are placed before the node; repeat until a full pass makes no move */
     390           55 :         auto result = features;
     391           55 :         result.erase(remove(result.begin(), result.end(), className), result.end()); // the class node is left out of the ordering
     392           55 :         bool ending{ false };
     393          157 :         while (!ending) {
     394          102 :             ending = true;
     395          951 :             for (auto feature : features) {
     396          849 :                 auto fathers = nodes[feature]->getParents();
     397         2250 :                 for (const auto& father : fathers) {
     398         1401 :                     auto fatherName = father->getName();
     399         1401 :                     if (fatherName == className) {
     400          745 :                         continue; // the class is not part of result, so it imposes no ordering constraint
     401              :                     }
     402              :                     // Check if father is placed before the actual feature
     403          656 :                     auto it = find(result.begin(), result.end(), fatherName);
     404          656 :                     if (it != result.end()) {
     405          656 :                         auto it2 = find(result.begin(), result.end(), feature);
     406          656 :                         if (it2 != result.end()) {
     407          656 :                             if (distance(it, it2) < 0) { // negative distance: feature currently sits before its parent
     408              :                                 // if it is not, move the parent so that it sits right before the feature
     409           61 :                                 result.erase(remove(result.begin(), result.end(), fatherName), result.end());
     410           61 :                                 result.insert(it2, fatherName); // it2 lies before the removed element, so it is still valid here (assuming names are unique in result)
     411           61 :                                 ending = false; // something moved: run another pass
     412              :                             }
     413              :                         } else {
     414            0 :                             throw std::logic_error("Error in topological sort because of node " + feature + " is not in result");
     415              :                         }
     416              :                     } else {
     417            0 :                         throw std::logic_error("Error in topological sort because of node father " + fatherName + " is not in result");
     418              :                     }
     419         1401 :                 }
     420          849 :             }
     421              :         }
     422           55 :         return result;
     423            0 :     }
     424            2 :     std::string Network::dump_cpt() const // Returns a text dump with, for each node, its name, number of states and CPT sizes, followed by the CPT itself.
     425              :     {
     426            2 :         std::stringstream oss;
     427           12 :         for (auto& node : nodes) {
     428           10 :             oss << "* " << node.first << ": (" << node.second->getNumStates() << ") : " << node.second->getCPT().sizes() << std::endl; // header line for the node
     429           10 :             oss << node.second->getCPT() << std::endl; // the CPT as streamed by its operator<<
     430              :         }
     431            4 :         return oss.str();
     432            2 :     }
     433              : }
        

Generated by: LCOV version 2.0-1