LCOV - code coverage report
Current view: top level - bayesnet/classifiers - KDB.cc (source / functions) Coverage Total Hit
Test: BayesNet Coverage Report Lines: 96.3 % 54 52
Test Date: 2024-05-06 17:54:04 Functions: 100.0 % 5 5
Legend: Lines: hit not hit

            Line data    Source code
       1              : // ***************************************************************
       2              : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
       3              : // SPDX-FileType: SOURCE
       4              : // SPDX-License-Identifier: MIT
       5              : // ***************************************************************
       6              : 
       7              : #include "KDB.h"
       8              : 
       9              : namespace bayesnet {
      10          148 :     KDB::KDB(int k, float theta) : Classifier(Network()), k(k), theta(theta)
      11              :     {
      12          444 :         validHyperparameters = { "k", "theta" };
      13              : 
      14          444 :     }
      15           12 :     void KDB::setHyperparameters(const nlohmann::json& hyperparameters_)
      16              :     {
      17           12 :         auto hyperparameters = hyperparameters_;
      18           12 :         if (hyperparameters.contains("k")) {
      19            4 :             k = hyperparameters["k"];
      20            4 :             hyperparameters.erase("k");
      21              :         }
      22           12 :         if (hyperparameters.contains("theta")) {
      23            4 :             theta = hyperparameters["theta"];
      24            4 :             hyperparameters.erase("theta");
      25              :         }
      26           12 :         Classifier::setHyperparameters(hyperparameters);
      27           12 :     }
      28           52 :     void KDB::buildModel(const torch::Tensor& weights)
      29              :     {
      30              :         /*
      31              :         1. For each feature Xi, compute mutual information, I(Xi;C),
      32              :         where C is the class.
      33              :         2. Compute class conditional mutual information I(Xi;Xj|C), for each
      34              :         pair of features Xi and Xj, where i != j.
      35              :         3. Let the used variable list, S, be empty.
      36              :         4. Let the DAG network being constructed, BN, begin with a single
      37              :         class node, C.
      38              :         5. Repeat until S includes all domain features
      39              :         5.1. Select feature Xmax which is not in S and has the largest value
      40              :         I(Xmax;C).
      41              :         5.2. Add a node to BN representing Xmax.
      42              :         5.3. Add an arc from C to Xmax in BN.
      43              :         5.4. Add m = min(|S|, k) arcs from m distinct features Xj in S with
      44              :         the highest value for I(Xmax;Xj|C).
      45              :         5.5. Add Xmax to S.
      46              :         Compute the conditional probability inferred by the structure of BN by
      47              :         using counts from DB, and output BN.
      48              :         */
      49              :         // 1. For each feature Xi, compute mutual information, I(Xi;C),
      50              :         // where C is the class.
      51           52 :         addNodes();
      52          156 :         const torch::Tensor& y = dataset.index({ -1, "..." });
      53           52 :         std::vector<double> mi;
      54          396 :         for (auto i = 0; i < features.size(); i++) {
      55         1032 :             torch::Tensor firstFeature = dataset.index({ i, "..." });
      56          344 :             mi.push_back(metrics.mutualInformation(firstFeature, y, weights));
      57          344 :         }
      58              :         // 2. Compute class conditional mutual information I(Xi;Xj|C) for each pair of features Xi and Xj, where i != j.
      59           52 :         auto conditionalEdgeWeights = metrics.conditionalEdge(weights);
      60              :         // 3. Let the used variable list, S, be empty.
      61           52 :         std::vector<int> S;
      62              :         // 4. Let the DAG network being constructed, BN, begin with a single
      63              :         // class node, C.
      64              :         // 5. Repeat until S includes all domain features
      65              :         // 5.1. Select feature Xmax which is not in S and has the largest value
      66              :         // I(Xmax;C).
      67           52 :         auto order = argsort(mi);
      68          396 :         for (auto idx : order) {
      69              :             // 5.2. Add a node to BN representing Xmax.
      70              :             // 5.3. Add an arc from C to Xmax in BN.
      71          344 :             model.addEdge(className, features[idx]);
      72              :             // 5.4. Add m = min(|S|, k) arcs from m distinct features Xj in S with
      73              :             // the highest value for I(Xmax;Xj|C).
      74          344 :             add_m_edges(idx, S, conditionalEdgeWeights);
      75              :             // 5.5. Add Xmax to S.
      76          344 :             S.push_back(idx);
      77              :         }
      78          448 :     }
      79          344 :     void KDB::add_m_edges(int idx, std::vector<int>& S, torch::Tensor& weights)
      80              :     {
      81          344 :         auto n_edges = std::min(k, static_cast<int>(S.size()));
      82          344 :         auto cond_w = clone(weights);
      83          344 :         bool exit_cond = k == 0;
      84          344 :         int num = 0;
      85         1004 :         while (!exit_cond) {
      86         2640 :             auto max_minfo = argmax(cond_w.index({ idx, "..." })).item<int>();
      87          660 :             auto belongs = find(S.begin(), S.end(), max_minfo) != S.end();
      88         1764 :             if (belongs && cond_w.index({ idx, max_minfo }).item<float>() > theta) {
      89              :                 try {
      90          320 :                     model.addEdge(features[max_minfo], features[idx]);
      91          320 :                     num++;
      92              :                 }
      93            0 :                 catch (const std::invalid_argument& e) {
      94              :                     // Loops are not allowed
      95            0 :                 }
      96              :             }
      97         2640 :             cond_w.index_put_({ idx, max_minfo }, -1);
      98         1980 :             auto candidates_mask = cond_w.index({ idx, "..." }).gt(theta);
      99          660 :             auto candidates = candidates_mask.nonzero();
     100          660 :             exit_cond = num == n_edges || candidates.size(0) == 0;
     101          660 :         }
     102         2692 :     }
     103            8 :     std::vector<std::string> KDB::graph(const std::string& title) const
     104              :     {
     105            8 :         std::string header{ title };
     106            8 :         if (title == "KDB") {
     107            8 :             header += " (k=" + std::to_string(k) + ", theta=" + std::to_string(theta) + ")";
     108              :         }
     109           16 :         return model.graph(header);
     110            8 :     }
     111              : }
        

Generated by: LCOV version 2.0-1