Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
#include "KDB.h"
#include <algorithm>
#include <limits>
8 :
9 : namespace bayesnet {
10 74 : KDB::KDB(int k, float theta) : Classifier(Network()), k(k), theta(theta)
11 : {
12 222 : validHyperparameters = { "k", "theta" };
13 :
14 222 : }
15 6 : void KDB::setHyperparameters(const nlohmann::json& hyperparameters_)
16 : {
17 6 : auto hyperparameters = hyperparameters_;
18 6 : if (hyperparameters.contains("k")) {
19 2 : k = hyperparameters["k"];
20 2 : hyperparameters.erase("k");
21 : }
22 6 : if (hyperparameters.contains("theta")) {
23 2 : theta = hyperparameters["theta"];
24 2 : hyperparameters.erase("theta");
25 : }
26 6 : Classifier::setHyperparameters(hyperparameters);
27 6 : }
28 26 : void KDB::buildModel(const torch::Tensor& weights)
29 : {
30 : /*
31 : 1. For each feature Xi, compute mutual information, I(X;C),
32 : where C is the class.
33 : 2. Compute class conditional mutual information I(Xi;XjIC), f or each
34 : pair of features Xi and Xj, where i#j.
35 : 3. Let the used variable list, S, be empty.
36 : 4. Let the DAG network being constructed, BN, begin with a single
37 : class node, C.
38 : 5. Repeat until S includes all domain features
39 : 5.1. Select feature Xmax which is not in S and has the largest value
40 : I(Xmax;C).
41 : 5.2. Add a node to BN representing Xmax.
42 : 5.3. Add an arc from C to Xmax in BN.
43 : 5.4. Add m = min(lSl,/c) arcs from m distinct features Xj in S with
44 : the highest value for I(Xmax;X,jC).
45 : 5.5. Add Xmax to S.
46 : Compute the conditional probabilility infered by the structure of BN by
47 : using counts from DB, and output BN.
48 : */
49 : // 1. For each feature Xi, compute mutual information, I(X;C),
50 : // where C is the class.
51 26 : addNodes();
52 78 : const torch::Tensor& y = dataset.index({ -1, "..." });
53 26 : std::vector<double> mi;
54 198 : for (auto i = 0; i < features.size(); i++) {
55 516 : torch::Tensor firstFeature = dataset.index({ i, "..." });
56 172 : mi.push_back(metrics.mutualInformation(firstFeature, y, weights));
57 172 : }
58 : // 2. Compute class conditional mutual information I(Xi;XjIC), f or each
59 26 : auto conditionalEdgeWeights = metrics.conditionalEdge(weights);
60 : // 3. Let the used variable list, S, be empty.
61 26 : std::vector<int> S;
62 : // 4. Let the DAG network being constructed, BN, begin with a single
63 : // class node, C.
64 : // 5. Repeat until S includes all domain features
65 : // 5.1. Select feature Xmax which is not in S and has the largest value
66 : // I(Xmax;C).
67 26 : auto order = argsort(mi);
68 198 : for (auto idx : order) {
69 : // 5.2. Add a node to BN representing Xmax.
70 : // 5.3. Add an arc from C to Xmax in BN.
71 172 : model.addEdge(className, features[idx]);
72 : // 5.4. Add m = min(lSl,/c) arcs from m distinct features Xj in S with
73 : // the highest value for I(Xmax;X,jC).
74 172 : add_m_edges(idx, S, conditionalEdgeWeights);
75 : // 5.5. Add Xmax to S.
76 172 : S.push_back(idx);
77 : }
78 224 : }
79 172 : void KDB::add_m_edges(int idx, std::vector<int>& S, torch::Tensor& weights)
80 : {
81 172 : auto n_edges = std::min(k, static_cast<int>(S.size()));
82 172 : auto cond_w = clone(weights);
83 172 : bool exit_cond = k == 0;
84 172 : int num = 0;
85 502 : while (!exit_cond) {
86 1320 : auto max_minfo = argmax(cond_w.index({ idx, "..." })).item<int>();
87 330 : auto belongs = find(S.begin(), S.end(), max_minfo) != S.end();
88 882 : if (belongs && cond_w.index({ idx, max_minfo }).item<float>() > theta) {
89 : try {
90 160 : model.addEdge(features[max_minfo], features[idx]);
91 160 : num++;
92 : }
93 0 : catch (const std::invalid_argument& e) {
94 : // Loops are not allowed
95 0 : }
96 : }
97 1320 : cond_w.index_put_({ idx, max_minfo }, -1);
98 990 : auto candidates_mask = cond_w.index({ idx, "..." }).gt(theta);
99 330 : auto candidates = candidates_mask.nonzero();
100 330 : exit_cond = num == n_edges || candidates.size(0) == 0;
101 330 : }
102 1346 : }
103 4 : std::vector<std::string> KDB::graph(const std::string& title) const
104 : {
105 4 : std::string header{ title };
106 4 : if (title == "KDB") {
107 4 : header += " (k=" + std::to_string(k) + ", theta=" + std::to_string(theta) + ")";
108 : }
109 8 : return model.graph(header);
110 4 : }
111 : }
|