From 3fcf1e40c98702b66186306e787d9c96ac3efe59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Thu, 13 Jul 2023 10:58:27 +0200 Subject: [PATCH] Complete KDB implementation --- src/KDB.cc | 35 +++++++++++++++++++++++++---------- src/KDB.h | 4 +++- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/src/KDB.cc b/src/KDB.cc index 178f1f0..152c4e0 100644 --- a/src/KDB.cc +++ b/src/KDB.cc @@ -12,7 +12,7 @@ namespace bayesnet { sort(indices.begin(), indices.end(), [&nums](int i, int j) {return nums[i] > nums[j];}); return indices; } - KDB::KDB(int k) : BaseClassifier(Network()), k(k) {} + KDB::KDB(int k, float theta = 0.03) : BaseClassifier(Network()), k(k), theta(theta) {} void KDB::train() { /* @@ -67,17 +67,32 @@ namespace bayesnet { model.addEdge(className, features[idx]); // 5.4. Add m = min(lSl,/c) arcs from m distinct features Xj in S with // the highest value for I(Xmax;X,jC). - // auto conditionalEdgeWeightsAccessor = conditionalEdgeWeights.accessor(); - // auto conditionalEdgeWeightsSorted = conditionalEdgeWeightsAccessor[idx].sort(); - // auto conditionalEdgeWeightsSortedAccessor = conditionalEdgeWeightsSorted.accessor(); - // for (auto i = 0; i < k; ++i) { - // auto index = conditionalEdgeWeightsSortedAccessor[i].item(); - // model.addEdge(features[idx], features[index]); - // } + add_m_edges(idx, S, conditionalEdgeWeights); // 5.5. Add Xmax to S. S.push_back(idx); } - } - + void KDB::add_m_edges(int idx, vector& S, Tensor& weights) + { + auto n_edges = min(k, static_cast(S.size())); + auto cond_w = clone(weights); + bool exit_cond = k == 0; + int num = 0; + while (!exit_cond) { + auto max_minfo = argmax(cond_w.index({ "...", idx })).item(); + auto belongs = find(S.begin(), S.end(), max_minfo) != S.end(); + if (belongs && cond_w.index({ idx, max_minfo }).item() > theta) { + try { + model.addEdge(features[idx], features[max_minfo]); + num++; + } + catch (const invalid_argument& e) { + // Loops are not allowed + } + } + cond_w.index_put_({ "...", max_minfo }, -1); + auto candidates = cond_w.gt(theta); + exit_cond = num == n_edges || candidates.size(0) == 0; + } + } } \ No newline at end of file diff --git a/src/KDB.h b/src/KDB.h index 2714cb7..5830140 100644 --- a/src/KDB.h +++ b/src/KDB.h @@ -7,10 +7,12 @@ namespace bayesnet { class KDB : public BaseClassifier { private: int k; + float theta; + void add_m_edges(int idx, vector& S, Tensor& weights); protected: void train(); public: - KDB(int k); + KDB(int k, float theta); }; } #endif \ No newline at end of file