Complete KDB implementation
This commit is contained in:
parent
786d781e29
commit
3fcf1e40c9
35
src/KDB.cc
35
src/KDB.cc
@ -12,7 +12,7 @@ namespace bayesnet {
|
||||
sort(indices.begin(), indices.end(), [&nums](int i, int j) {return nums[i] > nums[j];});
|
||||
return indices;
|
||||
}
|
||||
KDB::KDB(int k) : BaseClassifier(Network()), k(k) {}
|
||||
KDB::KDB(int k, float theta = 0.03) : BaseClassifier(Network()), k(k), theta(theta) {}
|
||||
void KDB::train()
|
||||
{
|
||||
/*
|
||||
@ -67,17 +67,32 @@ namespace bayesnet {
|
||||
model.addEdge(className, features[idx]);
|
||||
// 5.4. Add m = min(lSl,/c) arcs from m distinct features Xj in S with
|
||||
// the highest value for I(Xmax;X,jC).
|
||||
// auto conditionalEdgeWeightsAccessor = conditionalEdgeWeights.accessor<float, 2>();
|
||||
// auto conditionalEdgeWeightsSorted = conditionalEdgeWeightsAccessor[idx].sort();
|
||||
// auto conditionalEdgeWeightsSortedAccessor = conditionalEdgeWeightsSorted.accessor<float, 1>();
|
||||
// for (auto i = 0; i < k; ++i) {
|
||||
// auto index = conditionalEdgeWeightsSortedAccessor[i].item<int>();
|
||||
// model.addEdge(features[idx], features[index]);
|
||||
// }
|
||||
add_m_edges(idx, S, conditionalEdgeWeights);
|
||||
// 5.5. Add Xmax to S.
|
||||
S.push_back(idx);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void KDB::add_m_edges(int idx, vector<int>& S, Tensor& weights)
|
||||
{
|
||||
auto n_edges = min(k, static_cast<int>(S.size()));
|
||||
auto cond_w = clone(weights);
|
||||
bool exit_cond = k == 0;
|
||||
int num = 0;
|
||||
while (!exit_cond) {
|
||||
auto max_minfo = argmax(cond_w.index({ "...", idx })).item<int>();
|
||||
auto belongs = find(S.begin(), S.end(), max_minfo) != S.end();
|
||||
if (belongs && cond_w.index({ idx, max_minfo }).item<float>() > theta) {
|
||||
try {
|
||||
model.addEdge(features[idx], features[max_minfo]);
|
||||
num++;
|
||||
}
|
||||
catch (const invalid_argument& e) {
|
||||
// Loops are not allowed
|
||||
}
|
||||
}
|
||||
cond_w.index_put_({ "...", max_minfo }, -1);
|
||||
auto candidates = cond_w.gt(theta);
|
||||
exit_cond = num == n_edges || candidates.size(0) == 0;
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user