Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
7 : #include "KDB.h"
8 :
9 : namespace bayesnet {
10 37 : KDB::KDB(int k, float theta) : Classifier(Network()), k(k), theta(theta)
11 : {
12 111 : validHyperparameters = { "k", "theta" };
13 :
14 111 : }
15 3 : void KDB::setHyperparameters(const nlohmann::json& hyperparameters_)
16 : {
17 3 : auto hyperparameters = hyperparameters_;
18 3 : if (hyperparameters.contains("k")) {
19 1 : k = hyperparameters["k"];
20 1 : hyperparameters.erase("k");
21 : }
22 3 : if (hyperparameters.contains("theta")) {
23 1 : theta = hyperparameters["theta"];
24 1 : hyperparameters.erase("theta");
25 : }
26 3 : Classifier::setHyperparameters(hyperparameters);
27 3 : }
28 13 : void KDB::buildModel(const torch::Tensor& weights)
29 : {
30 : /*
31 : 1. For each feature Xi, compute mutual information, I(X;C),
32 : where C is the class.
33 : 2. Compute class conditional mutual information I(Xi;XjIC), f or each
34 : pair of features Xi and Xj, where i#j.
35 : 3. Let the used variable list, S, be empty.
36 : 4. Let the DAG network being constructed, BN, begin with a single
37 : class node, C.
38 : 5. Repeat until S includes all domain features
39 : 5.1. Select feature Xmax which is not in S and has the largest value
40 : I(Xmax;C).
41 : 5.2. Add a node to BN representing Xmax.
42 : 5.3. Add an arc from C to Xmax in BN.
43 : 5.4. Add m = min(lSl,/c) arcs from m distinct features Xj in S with
44 : the highest value for I(Xmax;X,jC).
45 : 5.5. Add Xmax to S.
46 : Compute the conditional probabilility infered by the structure of BN by
47 : using counts from DB, and output BN.
48 : */
49 : // 1. For each feature Xi, compute mutual information, I(X;C),
50 : // where C is the class.
51 13 : addNodes();
52 39 : const torch::Tensor& y = dataset.index({ -1, "..." });
53 13 : std::vector<double> mi;
54 99 : for (auto i = 0; i < features.size(); i++) {
55 258 : torch::Tensor firstFeature = dataset.index({ i, "..." });
56 86 : mi.push_back(metrics.mutualInformation(firstFeature, y, weights));
57 86 : }
58 : // 2. Compute class conditional mutual information I(Xi;XjIC), f or each
59 13 : auto conditionalEdgeWeights = metrics.conditionalEdge(weights);
60 : // 3. Let the used variable list, S, be empty.
61 13 : std::vector<int> S;
62 : // 4. Let the DAG network being constructed, BN, begin with a single
63 : // class node, C.
64 : // 5. Repeat until S includes all domain features
65 : // 5.1. Select feature Xmax which is not in S and has the largest value
66 : // I(Xmax;C).
67 13 : auto order = argsort(mi);
68 99 : for (auto idx : order) {
69 : // 5.2. Add a node to BN representing Xmax.
70 : // 5.3. Add an arc from C to Xmax in BN.
71 86 : model.addEdge(className, features[idx]);
72 : // 5.4. Add m = min(lSl,/c) arcs from m distinct features Xj in S with
73 : // the highest value for I(Xmax;X,jC).
74 86 : add_m_edges(idx, S, conditionalEdgeWeights);
75 : // 5.5. Add Xmax to S.
76 86 : S.push_back(idx);
77 : }
78 112 : }
79 86 : void KDB::add_m_edges(int idx, std::vector<int>& S, torch::Tensor& weights)
80 : {
81 86 : auto n_edges = std::min(k, static_cast<int>(S.size()));
82 86 : auto cond_w = clone(weights);
83 86 : bool exit_cond = k == 0;
84 86 : int num = 0;
85 251 : while (!exit_cond) {
86 660 : auto max_minfo = argmax(cond_w.index({ idx, "..." })).item<int>();
87 165 : auto belongs = find(S.begin(), S.end(), max_minfo) != S.end();
88 441 : if (belongs && cond_w.index({ idx, max_minfo }).item<float>() > theta) {
89 : try {
90 80 : model.addEdge(features[max_minfo], features[idx]);
91 80 : num++;
92 : }
93 0 : catch (const std::invalid_argument& e) {
94 : // Loops are not allowed
95 0 : }
96 : }
97 660 : cond_w.index_put_({ idx, max_minfo }, -1);
98 495 : auto candidates_mask = cond_w.index({ idx, "..." }).gt(theta);
99 165 : auto candidates = candidates_mask.nonzero();
100 165 : exit_cond = num == n_edges || candidates.size(0) == 0;
101 165 : }
102 673 : }
103 2 : std::vector<std::string> KDB::graph(const std::string& title) const
104 : {
105 2 : std::string header{ title };
106 2 : if (title == "KDB") {
107 2 : header += " (k=" + std::to_string(k) + ", theta=" + std::to_string(theta) + ")";
108 : }
109 4 : return model.graph(header);
110 2 : }
111 : }
|