Complete KDB implementation

This commit is contained in:
Ricardo Montañana Gómez 2023-07-13 10:58:27 +02:00
parent 786d781e29
commit 3fcf1e40c9
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
2 changed files with 28 additions and 11 deletions

View File

@ -12,7 +12,7 @@ namespace bayesnet {
sort(indices.begin(), indices.end(), [&nums](int i, int j) {return nums[i] > nums[j];}); sort(indices.begin(), indices.end(), [&nums](int i, int j) {return nums[i] > nums[j];});
return indices; return indices;
} }
KDB::KDB(int k) : BaseClassifier(Network()), k(k) {} KDB::KDB(int k, float theta = 0.03) : BaseClassifier(Network()), k(k), theta(theta) {}
void KDB::train() void KDB::train()
{ {
/* /*
@ -67,17 +67,32 @@ namespace bayesnet {
model.addEdge(className, features[idx]); model.addEdge(className, features[idx]);
// 5.4. Add m = min(lSl,/c) arcs from m distinct features Xj in S with // 5.4. Add m = min(lSl,/c) arcs from m distinct features Xj in S with
// the highest value for I(Xmax;X,jC). // the highest value for I(Xmax;X,jC).
// auto conditionalEdgeWeightsAccessor = conditionalEdgeWeights.accessor<float, 2>(); add_m_edges(idx, S, conditionalEdgeWeights);
// auto conditionalEdgeWeightsSorted = conditionalEdgeWeightsAccessor[idx].sort();
// auto conditionalEdgeWeightsSortedAccessor = conditionalEdgeWeightsSorted.accessor<float, 1>();
// for (auto i = 0; i < k; ++i) {
// auto index = conditionalEdgeWeightsSortedAccessor[i].item<int>();
// model.addEdge(features[idx], features[index]);
// }
// 5.5. Add Xmax to S. // 5.5. Add Xmax to S.
S.push_back(idx); S.push_back(idx);
} }
} }
void KDB::add_m_edges(int idx, vector<int>& S, Tensor& weights)
{
auto n_edges = min(k, static_cast<int>(S.size()));
auto cond_w = clone(weights);
bool exit_cond = k == 0;
int num = 0;
while (!exit_cond) {
auto max_minfo = argmax(cond_w.index({ "...", idx })).item<int>();
auto belongs = find(S.begin(), S.end(), max_minfo) != S.end();
if (belongs && cond_w.index({ idx, max_minfo }).item<float>() > theta) {
try {
model.addEdge(features[idx], features[max_minfo]);
num++;
}
catch (const invalid_argument& e) {
// Loops are not allowed
}
}
cond_w.index_put_({ "...", max_minfo }, -1);
auto candidates = cond_w.gt(theta);
exit_cond = num == n_edges || candidates.size(0) == 0;
}
}
} }

View File

@ -7,10 +7,12 @@ namespace bayesnet {
class KDB : public BaseClassifier { class KDB : public BaseClassifier {
private: private:
int k; int k;
float theta;
void add_m_edges(int idx, vector<int>& S, Tensor& weights);
protected: protected:
void train(); void train();
public: public:
KDB(int k); KDB(int k, float theta);
}; };
} }
#endif #endif