Files
BayesNet/html/bayesnet/classifiers/KDB.cc.gcov.html

17 KiB

<html lang="en"> <head> </head>
LCOV - code coverage report
Current view: top level - bayesnet/classifiers - KDB.cc (source / functions) Coverage Total Hit
Test: coverage.info Lines: 96.3 % 54 52
Test Date: 2024-04-21 17:30:26 Functions: 100.0 % 5 5

            Line data    Source code
       1              : // ***************************************************************
       2              : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
       3              : // SPDX-FileType: SOURCE
       4              : // SPDX-License-Identifier: MIT
       5              : // ***************************************************************
       6              : 
       7              : #include "KDB.h"
       8              : 
       9              : namespace bayesnet {
      10           37 :     KDB::KDB(int k, float theta) : Classifier(Network()), k(k), theta(theta)
      11              :     {
      12          111 :         validHyperparameters = { "k", "theta" };
      13              : 
      14          111 :     }
      15            3 :     void KDB::setHyperparameters(const nlohmann::json& hyperparameters_)
      16              :     {
      17            3 :         auto hyperparameters = hyperparameters_;
      18            3 :         if (hyperparameters.contains("k")) {
      19            1 :             k = hyperparameters["k"];
      20            1 :             hyperparameters.erase("k");
      21              :         }
      22            3 :         if (hyperparameters.contains("theta")) {
      23            1 :             theta = hyperparameters["theta"];
      24            1 :             hyperparameters.erase("theta");
      25              :         }
      26            3 :         Classifier::setHyperparameters(hyperparameters);
      27            3 :     }
      28           13 :     void KDB::buildModel(const torch::Tensor& weights)
      29              :     {
      30              :         /*
      31              :         1. For each feature Xi, compute mutual information, I(X;C),
      32              :         where C is the class.
      33              :         2. Compute class conditional mutual information I(Xi;XjIC), f or each
      34              :         pair of features Xi and Xj, where i#j.
      35              :         3. Let the used variable list, S, be empty.
      36              :         4. Let the DAG network being constructed, BN, begin with a single
      37              :         class node, C.
      38              :         5. Repeat until S includes all domain features
      39              :         5.1. Select feature Xmax which is not in S and has the largest value
      40              :         I(Xmax;C).
      41              :         5.2. Add a node to BN representing Xmax.
      42              :         5.3. Add an arc from C to Xmax in BN.
      43              :         5.4. Add m = min(lSl,/c) arcs from m distinct features Xj in S with
      44              :         the highest value for I(Xmax;X,jC).
      45              :         5.5. Add Xmax to S.
      46              :         Compute the conditional probabilility infered by the structure of BN by
      47              :         using counts from DB, and output BN.
      48              :         */
      49              :         // 1. For each feature Xi, compute mutual information, I(X;C),
      50              :         // where C is the class.
      51           13 :         addNodes();
      52           39 :         const torch::Tensor& y = dataset.index({ -1, "..." });
      53           13 :         std::vector<double> mi;
      54           99 :         for (auto i = 0; i < features.size(); i++) {
      55          258 :             torch::Tensor firstFeature = dataset.index({ i, "..." });
      56           86 :             mi.push_back(metrics.mutualInformation(firstFeature, y, weights));
      57           86 :         }
      58              :         // 2. Compute class conditional mutual information I(Xi;XjIC), f or each
      59           13 :         auto conditionalEdgeWeights = metrics.conditionalEdge(weights);
      60              :         // 3. Let the used variable list, S, be empty.
      61           13 :         std::vector<int> S;
      62              :         // 4. Let the DAG network being constructed, BN, begin with a single
      63              :         // class node, C.
      64              :         // 5. Repeat until S includes all domain features
      65              :         // 5.1. Select feature Xmax which is not in S and has the largest value
      66              :         // I(Xmax;C).
      67           13 :         auto order = argsort(mi);
      68           99 :         for (auto idx : order) {
      69              :             // 5.2. Add a node to BN representing Xmax.
      70              :             // 5.3. Add an arc from C to Xmax in BN.
      71           86 :             model.addEdge(className, features[idx]);
      72              :             // 5.4. Add m = min(lSl,/c) arcs from m distinct features Xj in S with
      73              :             // the highest value for I(Xmax;X,jC).
      74           86 :             add_m_edges(idx, S, conditionalEdgeWeights);
      75              :             // 5.5. Add Xmax to S.
      76           86 :             S.push_back(idx);
      77              :         }
      78          112 :     }
      79           86 :     void KDB::add_m_edges(int idx, std::vector<int>& S, torch::Tensor& weights)
      80              :     {
      81           86 :         auto n_edges = std::min(k, static_cast<int>(S.size()));
      82           86 :         auto cond_w = clone(weights);
      83           86 :         bool exit_cond = k == 0;
      84           86 :         int num = 0;
      85          251 :         while (!exit_cond) {
      86          660 :             auto max_minfo = argmax(cond_w.index({ idx, "..." })).item<int>();
      87          165 :             auto belongs = find(S.begin(), S.end(), max_minfo) != S.end();
      88          441 :             if (belongs && cond_w.index({ idx, max_minfo }).item<float>() > theta) {
      89              :                 try {
      90           80 :                     model.addEdge(features[max_minfo], features[idx]);
      91           80 :                     num++;
      92              :                 }
      93            0 :                 catch (const std::invalid_argument& e) {
      94              :                     // Loops are not allowed
      95            0 :                 }
      96              :             }
      97          660 :             cond_w.index_put_({ idx, max_minfo }, -1);
      98          495 :             auto candidates_mask = cond_w.index({ idx, "..." }).gt(theta);
      99          165 :             auto candidates = candidates_mask.nonzero();
     100          165 :             exit_cond = num == n_edges || candidates.size(0) == 0;
     101          165 :         }
     102          673 :     }
     103            2 :     std::vector<std::string> KDB::graph(const std::string& title) const
     104              :     {
     105            2 :         std::string header{ title };
     106            2 :         if (title == "KDB") {
     107            2 :             header += " (k=" + std::to_string(k) + ", theta=" + std::to_string(theta) + ")";
     108              :         }
     109            4 :         return model.graph(header);
     110            2 :     }
     111              : }
        

Generated by: LCOV version 2.0-1

</html>