Line data Source code
1 : // ***************************************************************
2 : // SPDX-FileCopyrightText: Copyright 2024 Ricardo Montañana Gómez
3 : // SPDX-FileType: SOURCE
4 : // SPDX-License-Identifier: MIT
5 : // ***************************************************************
6 :
7 : #include "KDB.h"
8 :
9 : namespace bayesnet {
// Constructs a KDB classifier on top of a default-constructed Network.
//   k     - stored in the member k; add_m_edges uses it as the upper bound on
//           the number of feature parents added per node.
//   theta - stored in the member theta; add_m_edges only accepts a parent whose
//           conditional edge weight is strictly greater than this value.
// Registers "k" and "theta" as the hyperparameter names this class accepts.
10 148 : KDB::KDB(int k, float theta) : Classifier(Network()), k(k), theta(theta)
11 : {
12 444 : validHyperparameters = { "k", "theta" };
13 :
14 444 : }
// Applies the hyperparameters this class owns ("k" and "theta") and forwards
// whatever is left to Classifier::setHyperparameters.
//   hyperparameters_ - JSON object; it is copied first, so the caller's object
//                      is never modified.
// NOTE(review): the base class presumably validates or rejects the remaining
// keys; that behaviour is not visible here - confirm.
15 12 : void KDB::setHyperparameters(const nlohmann::json& hyperparameters_)
16 : {
// Work on a copy so that consumed keys can be erased.
17 12 : auto hyperparameters = hyperparameters_;
18 12 : if (hyperparameters.contains("k")) {
19 4 : k = hyperparameters["k"];
20 4 : hyperparameters.erase("k");
21 : }
22 12 : if (hyperparameters.contains("theta")) {
23 4 : theta = hyperparameters["theta"];
24 4 : hyperparameters.erase("theta");
25 : }
// Only the keys not consumed above reach the base class.
26 12 : Classifier::setHyperparameters(hyperparameters);
27 12 : }
// Builds the KDB network structure: every feature gets an arc from the class
// node, plus the feature parents chosen by add_m_edges among the features
// processed before it.
//   weights - passed unchanged to metrics.mutualInformation and
//             metrics.conditionalEdge.
28 52 : void KDB::buildModel(const torch::Tensor& weights)
29 : {
30 : /*
31 : 1. For each feature Xi, compute mutual information, I(Xi;C),
32 : where C is the class.
33 : 2. Compute class conditional mutual information I(Xi;Xj|C), for each
34 : pair of features Xi and Xj, where i != j.
35 : 3. Let the used variable list, S, be empty.
36 : 4. Let the DAG network being constructed, BN, begin with a single
37 : class node, C.
38 : 5. Repeat until S includes all domain features
39 : 5.1. Select feature Xmax which is not in S and has the largest value
40 : I(Xmax;C).
41 : 5.2. Add a node to BN representing Xmax.
42 : 5.3. Add an arc from C to Xmax in BN.
43 : 5.4. Add m = min(|S|, k) arcs from m distinct features Xj in S with
44 : the highest value for I(Xmax;Xj|C).
45 : 5.5. Add Xmax to S.
46 : Compute the conditional probability inferred by the structure of BN by
47 : using counts from DB, and output BN.
48 : */
49 : // 1. For each feature Xi, compute mutual information, I(Xi;C),
50 : // where C is the class.
51 52 : addNodes();
// The last row of dataset is used as the class vector y.
52 156 : const torch::Tensor& y = dataset.index({ -1, "..." });
53 52 : std::vector<double> mi;
// NOTE(review): i is deduced as int while features.size() is presumably an
// unsigned size type, so this is a signed/unsigned comparison - confirm.
54 396 : for (auto i = 0; i < features.size(); i++) {
// Row i of dataset holds the values of feature i.
55 1032 : torch::Tensor firstFeature = dataset.index({ i, "..." });
56 344 : mi.push_back(metrics.mutualInformation(firstFeature, y, weights));
57 344 : }
58 : // 2. Compute class conditional mutual information I(Xi;Xj|C), for each pair of features Xi and Xj, where i != j.
59 52 : auto conditionalEdgeWeights = metrics.conditionalEdge(weights);
60 : // 3. Let the used variable list, S, be empty.
61 52 : std::vector<int> S;
62 : // 4. Let the DAG network being constructed, BN, begin with a single
63 : // class node, C.
64 : // 5. Repeat until S includes all domain features
65 : // 5.1. Select feature Xmax which is not in S and has the largest value
66 : // I(Xmax;C).
// NOTE(review): step 5.1 needs the features visited from largest to smallest
// I(Xi;C); this presumes argsort returns indices in descending order of mi -
// confirm against argsort's definition.
67 52 : auto order = argsort(mi);
68 396 : for (auto idx : order) {
69 : // 5.2. Add a node to BN representing Xmax.
70 : // 5.3. Add an arc from C to Xmax in BN.
71 344 : model.addEdge(className, features[idx]);
72 : // 5.4. Add m = min(|S|, k) arcs from m distinct features Xj in S with
73 : // the highest value for I(Xmax;Xj|C).
74 344 : add_m_edges(idx, S, conditionalEdgeWeights);
75 : // 5.5. Add Xmax to S.
76 344 : S.push_back(idx);
77 : }
78 448 : }
// Adds up to min(k, |S|) arcs into features[idx], choosing parents greedily
// from S by the largest entry in row idx of weights.
//   idx     - index of the feature receiving the arcs.
//   S       - indices of the features already placed; only these can become
//             parents. It is read but not modified here.
//   weights - matrix indexed as [idx, j]; the caller passes the result of
//             metrics.conditionalEdge. The function works on a clone.
// A candidate is accepted only if it is in S and its weight is strictly
// greater than theta.
79 344 : void KDB::add_m_edges(int idx, std::vector<int>& S, torch::Tensor& weights)
80 : {
81 344 : auto n_edges = std::min(k, static_cast<int>(S.size()));
// Clone so that the masking writes below do not alter the caller's tensor.
82 344 : auto cond_w = clone(weights);
// With k == 0 no feature parents are wanted, so the loop is skipped entirely.
// NOTE(review): when S is empty but k != 0 the loop still runs one iteration
// that cannot add anything (n_edges is 0) before exiting.
83 344 : bool exit_cond = k == 0;
// Number of arcs successfully added so far.
84 344 : int num = 0;
85 1004 : while (!exit_cond) {
// Column holding the largest remaining weight in row idx.
86 2640 : auto max_minfo = argmax(cond_w.index({ idx, "..." })).item<int>();
87 660 : auto belongs = find(S.begin(), S.end(), max_minfo) != S.end();
88 1764 : if (belongs && cond_w.index({ idx, max_minfo }).item<float>() > theta) {
89 : try {
90 320 : model.addEdge(features[max_minfo], features[idx]);
91 320 : num++;
92 : }
// NOTE(review): the exception is swallowed on purpose, so a rejected arc is
// simply not counted. This handler has hit count 0 in this coverage report.
93 0 : catch (const std::invalid_argument& e) {
94 : // Loops are not allowed: the rejected arc is skipped.
95 0 : }
96 : }
// Mask the entry just examined so that argmax moves on to the next candidate.
// NOTE(review): -1 is only below the gt(theta) cut-off while theta >= -1;
// confirm theta is never set lower than that.
97 2640 : cond_w.index_put_({ idx, max_minfo }, -1);
// Entries of row idx that are still strictly above theta.
98 1980 : auto candidates_mask = cond_w.index({ idx, "..." }).gt(theta);
99 660 : auto candidates = candidates_mask.nonzero();
// Stop once enough arcs were added or nothing above theta is left.
100 660 : exit_cond = num == n_edges || candidates.size(0) == 0;
101 660 : }
102 2692 : }
// Returns model.graph(header) for this classifier.
//   title - header text; when it is exactly "KDB", the current k and theta are
//           appended as " (k=<k>, theta=<theta>)" (both rendered with
//           std::to_string). Any other title is passed through unchanged.
103 8 : std::vector<std::string> KDB::graph(const std::string& title) const
104 : {
105 8 : std::string header{ title };
106 8 : if (title == "KDB") {
107 8 : header += " (k=" + std::to_string(k) + ", theta=" + std::to_string(theta) + ")";
108 : }
109 16 : return model.graph(header);
110 8 : }
111 : }
|