LCOV - code coverage report
Current view: top level - bayesnet/network - Network.cc (source / functions) Coverage Total Hit
Test: BayesNet Coverage Report Lines: 100.0 % 295 295
Test Date: 2024-05-06 17:54:04 Functions: 100.0 % 40 40
Legend: Lines: hit not hit

            Line data    Source code
       1              : // ***************************************************************
       2              : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
       3              : // SPDX-FileType: SOURCE
       4              : // SPDX-License-Identifier: MIT
       5              : // ***************************************************************
       6              : 
       7              : #include <thread>
       8              : #include <mutex>
       9              : #include <sstream>
      10              : #include "Network.h"
      11              : #include "bayesnet/utils/bayesnetUtils.h"
      12              : namespace bayesnet {
      13         2332 :     Network::Network() : fitted{ false }, maxThreads{ 0.95 }, classNumStates{ 0 }, laplaceSmoothing{ 0 }
      14              :     {
      15         2332 :     }
      16            8 :     Network::Network(float maxT) : fitted{ false }, maxThreads{ maxT }, classNumStates{ 0 }, laplaceSmoothing{ 0 }
      17              :     {
      18              : 
      19            8 :     }
      20         2244 :     Network::Network(const Network& other) : laplaceSmoothing(other.laplaceSmoothing), features(other.features), className(other.className), classNumStates(other.getClassNumStates()),
      21         4488 :         maxThreads(other.getMaxThreads()), fitted(other.fitted), samples(other.samples)
      22              :     {
      23         2244 :         if (samples.defined())
      24            4 :             samples = samples.clone();
      25         2264 :         for (const auto& node : other.nodes) {
      26           20 :             nodes[node.first] = std::make_unique<Node>(*node.second);
      27              :         }
      28         2244 :     }
      29         1740 :     void Network::initialize()
      30              :     {
      31         1740 :         features.clear();
      32         1740 :         className = "";
      33         1740 :         classNumStates = 0;
      34         1740 :         fitted = false;
      35         1740 :         nodes.clear();
      36         1740 :         samples = torch::Tensor();
      37         1740 :     }
      38         2256 :     float Network::getMaxThreads() const
      39              :     {
      40         2256 :         return maxThreads;
      41              :     }
      42           48 :     torch::Tensor& Network::getSamples()
      43              :     {
      44           48 :         return samples;
      45              :     }
      46        31216 :     void Network::addNode(const std::string& name)
      47              :     {
      48        31216 :         if (name == "") {
      49            8 :             throw std::invalid_argument("Node name cannot be empty");
      50              :         }
      51        31208 :         if (nodes.find(name) != nodes.end()) {
      52            4 :             return;
      53              :         }
      54        31204 :         if (find(features.begin(), features.end(), name) == features.end()) {
      55        31204 :             features.push_back(name);
      56              :         }
      57        31204 :         nodes[name] = std::make_unique<Node>(name);
      58              :     }
      59          380 :     std::vector<std::string> Network::getFeatures() const
      60              :     {
      61          380 :         return features;
      62              :     }
      63         2616 :     int Network::getClassNumStates() const
      64              :     {
      65         2616 :         return classNumStates;
      66              :     }
      67           48 :     int Network::getStates() const
      68              :     {
      69           48 :         int result = 0;
      70          288 :         for (auto& node : nodes) {
      71          240 :             result += node.second->getNumStates();
      72              :         }
      73           48 :         return result;
      74              :     }
      75      3735008 :     std::string Network::getClassName() const
      76              :     {
      77      3735008 :         return className;
      78              :     }
      79        70324 :     bool Network::isCyclic(const std::string& nodeId, std::unordered_set<std::string>& visited, std::unordered_set<std::string>& recStack)
      80              :     {
      81        70324 :         if (visited.find(nodeId) == visited.end()) // if node hasn't been visited yet
      82              :         {
      83        70324 :             visited.insert(nodeId);
      84        70324 :             recStack.insert(nodeId);
      85        81496 :             for (Node* child : nodes[nodeId]->getChildren()) {
      86        11196 :                 if (visited.find(child->getName()) == visited.end() && isCyclic(child->getName(), visited, recStack))
      87           24 :                     return true;
      88        11180 :                 if (recStack.find(child->getName()) != recStack.end())
      89            8 :                     return true;
      90              :             }
      91              :         }
      92        70300 :         recStack.erase(nodeId); // remove node from recursion stack before function ends
      93        70300 :         return false;
      94              :     }
      95        59152 :     void Network::addEdge(const std::string& parent, const std::string& child)
      96              :     {
      97        59152 :         if (nodes.find(parent) == nodes.end()) {
      98            8 :             throw std::invalid_argument("Parent node " + parent + " does not exist");
      99              :         }
     100        59144 :         if (nodes.find(child) == nodes.end()) {
     101            8 :             throw std::invalid_argument("Child node " + child + " does not exist");
     102              :         }
     103              :         // Temporarily add edge to check for cycles
     104        59136 :         nodes[parent]->addChild(nodes[child].get());
     105        59136 :         nodes[child]->addParent(nodes[parent].get());
     106        59136 :         std::unordered_set<std::string> visited;
     107        59136 :         std::unordered_set<std::string> recStack;
     108        59136 :         if (isCyclic(nodes[child]->getName(), visited, recStack)) // if adding this edge forms a cycle
     109              :         {
     110              :             // remove problematic edge
     111            8 :             nodes[parent]->removeChild(nodes[child].get());
     112            8 :             nodes[child]->removeParent(nodes[parent].get());
     113            8 :             throw std::invalid_argument("Adding this edge forms a cycle in the graph.");
     114              :         }
     115        59144 :     }
     116      3735276 :     std::map<std::string, std::unique_ptr<Node>>& Network::getNodes()
     117              :     {
     118      3735276 :         return nodes;
     119              :     }
     120         1888 :     void Network::checkFitData(int n_samples, int n_features, int n_samples_y, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights)
     121              :     {
     122         1888 :         if (weights.size(0) != n_samples) {
     123            8 :             throw std::invalid_argument("Weights (" + std::to_string(weights.size(0)) + ") must have the same number of elements as samples (" + std::to_string(n_samples) + ") in Network::fit");
     124              :         }
     125         1880 :         if (n_samples != n_samples_y) {
     126            8 :             throw std::invalid_argument("X and y must have the same number of samples in Network::fit (" + std::to_string(n_samples) + " != " + std::to_string(n_samples_y) + ")");
     127              :         }
     128         1872 :         if (n_features != featureNames.size()) {
     129            8 :             throw std::invalid_argument("X and features must have the same number of features in Network::fit (" + std::to_string(n_features) + " != " + std::to_string(featureNames.size()) + ")");
     130              :         }
     131         1864 :         if (features.size() == 0) {
     132            8 :             throw std::invalid_argument("The network has not been initialized. You must call addNode() before calling fit()");
     133              :         }
     134         1856 :         if (n_features != features.size() - 1) {
     135            8 :             throw std::invalid_argument("X and local features must have the same number of features in Network::fit (" + std::to_string(n_features) + " != " + std::to_string(features.size() - 1) + ")");
     136              :         }
     137         1848 :         if (find(features.begin(), features.end(), className) == features.end()) {
     138            8 :             throw std::invalid_argument("Class Name not found in Network::features");
     139              :         }
     140        32868 :         for (auto& feature : featureNames) {
     141        31044 :             if (find(features.begin(), features.end(), feature) == features.end()) {
     142            8 :                 throw std::invalid_argument("Feature " + feature + " not found in Network::features");
     143              :             }
     144        31036 :             if (states.find(feature) == states.end()) {
     145            8 :                 throw std::invalid_argument("Feature " + feature + " not found in states");
     146              :             }
     147              :         }
     148         1824 :     }
     149         1824 :     void Network::setStates(const std::map<std::string, std::vector<int>>& states)
     150              :     {
     151              :         // Set states to every Node in the network
     152         1824 :         for_each(features.begin(), features.end(), [this, &states](const std::string& feature) {
     153        32828 :             nodes.at(feature)->setNumStates(states.at(feature).size());
     154        32828 :             });
     155         1824 :         classNumStates = nodes.at(className)->getNumStates();
     156         1824 :     }
     157              :     // X comes in nxm, where n is the number of features and m the number of samples
     158            4 :     void Network::fit(const torch::Tensor& X, const torch::Tensor& y, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states)
     159              :     {
     160            4 :         checkFitData(X.size(1), X.size(0), y.size(0), featureNames, className, states, weights);
     161            4 :         this->className = className;
     162            4 :         torch::Tensor ytmp = torch::transpose(y.view({ y.size(0), 1 }), 0, 1);
     163           12 :         samples = torch::cat({ X , ytmp }, 0);
     164           20 :         for (int i = 0; i < featureNames.size(); ++i) {
     165           48 :             auto row_feature = X.index({ i, "..." });
     166           16 :         }
     167            4 :         completeFit(states, weights);
     168           24 :     }
     169         1792 :     void Network::fit(const torch::Tensor& samples, const torch::Tensor& weights, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states)
     170              :     {
     171         1792 :         checkFitData(samples.size(1), samples.size(0) - 1, samples.size(1), featureNames, className, states, weights);
     172         1792 :         this->className = className;
     173         1792 :         this->samples = samples;
     174         1792 :         completeFit(states, weights);
     175         1792 :     }
     176              :     // input_data comes in nxm, where n is the number of features and m the number of samples
     177           92 :     void Network::fit(const std::vector<std::vector<int>>& input_data, const std::vector<int>& labels, const std::vector<double>& weights_, const std::vector<std::string>& featureNames, const std::string& className, const std::map<std::string, std::vector<int>>& states)
     178              :     {
     179           92 :         const torch::Tensor weights = torch::tensor(weights_, torch::kFloat64);
     180           92 :         checkFitData(input_data[0].size(), input_data.size(), labels.size(), featureNames, className, states, weights);
     181           28 :         this->className = className;
     182              :         // Build tensor of samples (nxm) (n+1 because of the class)
     183           28 :         samples = torch::zeros({ static_cast<int>(input_data.size() + 1), static_cast<int>(input_data[0].size()) }, torch::kInt32);
     184          140 :         for (int i = 0; i < featureNames.size(); ++i) {
     185          448 :             samples.index_put_({ i, "..." }, torch::tensor(input_data[i], torch::kInt32));
     186              :         }
     187          112 :         samples.index_put_({ -1, "..." }, torch::tensor(labels, torch::kInt32));
     188           28 :         completeFit(states, weights);
     189          232 :     }
     190         1824 :     void Network::completeFit(const std::map<std::string, std::vector<int>>& states, const torch::Tensor& weights)
     191              :     {
     192         1824 :         setStates(states);
     193         1824 :         laplaceSmoothing = 1.0 / samples.size(1); // To use in CPT computation
     194         1824 :         std::vector<std::thread> threads;
     195        34652 :         for (auto& node : nodes) {
     196        32828 :             threads.emplace_back([this, &node, &weights]() {
     197        32828 :                 node.second->computeCPT(samples, features, laplaceSmoothing, weights);
     198        32828 :                 });
     199              :         }
     200        34652 :         for (auto& thread : threads) {
     201        32828 :             thread.join();
     202              :         }
     203         1824 :         fitted = true;
     204         1824 :     }
     205         3320 :     torch::Tensor Network::predict_tensor(const torch::Tensor& samples, const bool proba)
     206              :     {
     207         3320 :         if (!fitted) {
     208            8 :             throw std::logic_error("You must call fit() before calling predict()");
     209              :         }
     210         3312 :         torch::Tensor result;
     211         3312 :         result = torch::zeros({ samples.size(1), classNumStates }, torch::kFloat64);
     212       785016 :         for (int i = 0; i < samples.size(1); ++i) {
     213      2345136 :             const torch::Tensor sample = samples.index({ "...", i });
     214       781712 :             auto psample = predict_sample(sample);
     215       781704 :             auto temp = torch::tensor(psample, torch::kFloat64);
     216              :             //            result.index_put_({ i, "..." }, torch::tensor(predict_sample(sample), torch::kFloat64));
     217      2345112 :             result.index_put_({ i, "..." }, temp);
     218       781712 :         }
     219         3304 :         if (proba)
     220         1476 :             return result;
     221         3656 :         return result.argmax(1);
     222      1566728 :     }
     223              :     // Return mxn tensor of probabilities
     224         1476 :     torch::Tensor Network::predict_proba(const torch::Tensor& samples)
     225              :     {
     226         1476 :         return predict_tensor(samples, true);
     227              :     }
     228              : 
     229              :     // Return mxn tensor of probabilities
     230         1844 :     torch::Tensor Network::predict(const torch::Tensor& samples)
     231              :     {
     232         1844 :         return predict_tensor(samples, false);
     233              :     }
     234              : 
     235              :     // Return mx1 std::vector of predictions
     236              :     // tsamples is nxm std::vector of samples
     237           48 :     std::vector<int> Network::predict(const std::vector<std::vector<int>>& tsamples)
     238              :     {
     239           48 :         if (!fitted) {
     240           16 :             throw std::logic_error("You must call fit() before calling predict()");
     241              :         }
     242           32 :         std::vector<int> predictions;
     243           32 :         std::vector<int> sample;
     244         3564 :         for (int row = 0; row < tsamples[0].size(); ++row) {
     245         3540 :             sample.clear();
     246        26252 :             for (int col = 0; col < tsamples.size(); ++col) {
     247        22712 :                 sample.push_back(tsamples[col][row]);
     248              :             }
     249         3540 :             std::vector<double> classProbabilities = predict_sample(sample);
     250              :             // Find the class with the maximum posterior probability
     251         3532 :             auto maxElem = max_element(classProbabilities.begin(), classProbabilities.end());
     252         3532 :             int predictedClass = distance(classProbabilities.begin(), maxElem);
     253         3532 :             predictions.push_back(predictedClass);
     254         3532 :         }
     255           48 :         return predictions;
     256           40 :     }
     257              :     // Return mxn std::vector of probabilities
     258              :     // tsamples is nxm std::vector of samples
     259          552 :     std::vector<std::vector<double>> Network::predict_proba(const std::vector<std::vector<int>>& tsamples)
     260              :     {
     261          552 :         if (!fitted) {
     262            8 :             throw std::logic_error("You must call fit() before calling predict_proba()");
     263              :         }
     264          544 :         std::vector<std::vector<double>> predictions;
     265          544 :         std::vector<int> sample;
     266       111516 :         for (int row = 0; row < tsamples[0].size(); ++row) {
     267       110972 :             sample.clear();
     268      1055620 :             for (int col = 0; col < tsamples.size(); ++col) {
     269       944648 :                 sample.push_back(tsamples[col][row]);
     270              :             }
     271       110972 :             predictions.push_back(predict_sample(sample));
     272              :         }
     273         1088 :         return predictions;
     274          544 :     }
     275           20 :     double Network::score(const std::vector<std::vector<int>>& tsamples, const std::vector<int>& labels)
     276              :     {
     277           20 :         std::vector<int> y_pred = predict(tsamples);
     278           12 :         int correct = 0;
     279         2324 :         for (int i = 0; i < y_pred.size(); ++i) {
     280         2312 :             if (y_pred[i] == labels[i]) {
     281         1944 :                 correct++;
     282              :             }
     283              :         }
     284           24 :         return (double)correct / y_pred.size();
     285           12 :     }
     286              :     // Return 1xn std::vector of probabilities
     287       114512 :     std::vector<double> Network::predict_sample(const std::vector<int>& sample)
     288              :     {
     289              :         // Ensure the sample size is equal to the number of features
     290       114512 :         if (sample.size() != features.size() - 1) {
     291           16 :             throw std::invalid_argument("Sample size (" + std::to_string(sample.size()) +
     292           24 :                 ") does not match the number of features (" + std::to_string(features.size() - 1) + ")");
     293              :         }
     294       114504 :         std::map<std::string, int> evidence;
     295      1081840 :         for (int i = 0; i < sample.size(); ++i) {
     296       967336 :             evidence[features[i]] = sample[i];
     297              :         }
     298       229008 :         return exactInference(evidence);
     299       114504 :     }
     300              :     // Return 1xn std::vector of probabilities
     301       781712 :     std::vector<double> Network::predict_sample(const torch::Tensor& sample)
     302              :     {
     303              :         // Ensure the sample size is equal to the number of features
     304       781712 :         if (sample.size(0) != features.size() - 1) {
     305           16 :             throw std::invalid_argument("Sample size (" + std::to_string(sample.size(0)) +
     306           24 :                 ") does not match the number of features (" + std::to_string(features.size() - 1) + ")");
     307              :         }
     308       781704 :         std::map<std::string, int> evidence;
     309     18085136 :         for (int i = 0; i < sample.size(0); ++i) {
     310     17303432 :             evidence[features[i]] = sample[i].item<int>();
     311              :         }
     312      1563408 :         return exactInference(evidence);
     313       781704 :     }
     314      3734984 :     double Network::computeFactor(std::map<std::string, int>& completeEvidence)
     315              :     {
     316      3734984 :         double result = 1.0;
     317     72886736 :         for (auto& node : getNodes()) {
     318     69151752 :             result *= node.second->getFactorValue(completeEvidence);
     319              :         }
     320      3734984 :         return result;
     321              :     }
     322       896208 :     std::vector<double> Network::exactInference(std::map<std::string, int>& evidence)
     323              :     {
     324       896208 :         std::vector<double> result(classNumStates, 0.0);
     325       896208 :         std::vector<std::thread> threads;
     326       896208 :         std::mutex mtx;
     327      4631192 :         for (int i = 0; i < classNumStates; ++i) {
     328      3734984 :             threads.emplace_back([this, &result, &evidence, i, &mtx]() {
     329      3734984 :                 auto completeEvidence = std::map<std::string, int>(evidence);
     330      3734984 :                 completeEvidence[getClassName()] = i;
     331      3734984 :                 double factor = computeFactor(completeEvidence);
     332      3734984 :                 std::lock_guard<std::mutex> lock(mtx);
     333      3734984 :                 result[i] = factor;
     334      3734984 :                 });
     335              :         }
     336      4631192 :         for (auto& thread : threads) {
     337      3734984 :             thread.join();
     338              :         }
     339              :         // Normalize result
     340       896208 :         double sum = accumulate(result.begin(), result.end(), 0.0);
     341      4631192 :         transform(result.begin(), result.end(), result.begin(), [sum](const double& value) { return value / sum; });
     342      1792416 :         return result;
     343       896208 :     }
     344           28 :     std::vector<std::string> Network::show() const
     345              :     {
     346           28 :         std::vector<std::string> result;
     347              :         // Draw the network
     348          160 :         for (auto& node : nodes) {
     349          132 :             std::string line = node.first + " -> ";
     350          308 :             for (auto child : node.second->getChildren()) {
     351          176 :                 line += child->getName() + ", ";
     352              :             }
     353          132 :             result.push_back(line);
     354          132 :         }
     355           56 :         return result;
     356           28 :     }
     357          112 :     std::vector<std::string> Network::graph(const std::string& title) const
     358              :     {
     359          112 :         auto output = std::vector<std::string>();
     360          112 :         auto prefix = "digraph BayesNet {\nlabel=<BayesNet ";
     361          112 :         auto suffix = ">\nfontsize=30\nfontcolor=blue\nlabelloc=t\nlayout=circo\n";
     362          112 :         std::string header = prefix + title + suffix;
     363          112 :         output.push_back(header);
     364          844 :         for (auto& node : nodes) {
     365          732 :             auto result = node.second->graph(className);
     366          732 :             output.insert(output.end(), result.begin(), result.end());
     367          732 :         }
     368          112 :         output.push_back("}\n");
     369          224 :         return output;
     370          112 :     }
     371          408 :     std::vector<std::pair<std::string, std::string>> Network::getEdges() const
     372              :     {
     373          408 :         auto edges = std::vector<std::pair<std::string, std::string>>();
     374         7396 :         for (const auto& node : nodes) {
     375         6988 :             auto head = node.first;
     376        20312 :             for (const auto& child : node.second->getChildren()) {
     377        13324 :                 auto tail = child->getName();
     378        13324 :                 edges.push_back({ head, tail });
     379        13324 :             }
     380         6988 :         }
     381          816 :         return edges;
     382          408 :     }
     383          364 :     int Network::getNumEdges() const
     384              :     {
     385          364 :         return getEdges().size();
     386              :     }
     387          220 :     std::vector<std::string> Network::topological_sort()
     388              :     {
     389              :         /* Check if al the fathers of every node are before the node */
     390          220 :         auto result = features;
     391          220 :         result.erase(remove(result.begin(), result.end(), className), result.end());
     392          220 :         bool ending{ false };
     393          628 :         while (!ending) {
     394          408 :             ending = true;
     395         3804 :             for (auto feature : features) {
     396         3396 :                 auto fathers = nodes[feature]->getParents();
     397         9000 :                 for (const auto& father : fathers) {
     398         5604 :                     auto fatherName = father->getName();
     399         5604 :                     if (fatherName == className) {
     400         2980 :                         continue;
     401              :                     }
     402              :                     // Check if father is placed before the actual feature
     403         2624 :                     auto it = find(result.begin(), result.end(), fatherName);
     404         2624 :                     if (it != result.end()) {
     405         2624 :                         auto it2 = find(result.begin(), result.end(), feature);
     406         2624 :                         if (it2 != result.end()) {
     407         5248 :                             if (distance(it, it2) < 0) {
     408              :                                 // if it is not, insert it before the feature
     409          244 :                                 result.erase(remove(result.begin(), result.end(), fatherName), result.end());
     410          244 :                                 result.insert(it2, fatherName);
     411          244 :                                 ending = false;
     412              :                             }
     413              :                         }
     414              :                     }
     415         5604 :                 }
     416         3396 :             }
     417              :         }
     418          440 :         return result;
     419          220 :     }
     420            8 :     std::string Network::dump_cpt() const
     421              :     {
     422            8 :         std::stringstream oss;
     423           48 :         for (auto& node : nodes) {
     424           40 :             oss << "* " << node.first << ": (" << node.second->getNumStates() << ") : " << node.second->getCPT().sizes() << std::endl;
     425           40 :             oss << node.second->getCPT() << std::endl;
     426              :         }
     427           16 :         return oss.str();
     428            8 :     }
     429              : }
        

Generated by: LCOV version 2.0-1