Complete KDB implementation
This commit is contained in:
parent
786d781e29
commit
3fcf1e40c9
35
src/KDB.cc
35
src/KDB.cc
@ -12,7 +12,7 @@ namespace bayesnet {
|
|||||||
sort(indices.begin(), indices.end(), [&nums](int i, int j) {return nums[i] > nums[j];});
|
sort(indices.begin(), indices.end(), [&nums](int i, int j) {return nums[i] > nums[j];});
|
||||||
return indices;
|
return indices;
|
||||||
}
|
}
|
||||||
KDB::KDB(int k) : BaseClassifier(Network()), k(k) {}
|
KDB::KDB(int k, float theta = 0.03) : BaseClassifier(Network()), k(k), theta(theta) {}
|
||||||
void KDB::train()
|
void KDB::train()
|
||||||
{
|
{
|
||||||
/*
|
/*
|
||||||
@ -67,17 +67,32 @@ namespace bayesnet {
|
|||||||
model.addEdge(className, features[idx]);
|
model.addEdge(className, features[idx]);
|
||||||
// 5.4. Add m = min(lSl,/c) arcs from m distinct features Xj in S with
|
// 5.4. Add m = min(lSl,/c) arcs from m distinct features Xj in S with
|
||||||
// the highest value for I(Xmax;X,jC).
|
// the highest value for I(Xmax;X,jC).
|
||||||
// auto conditionalEdgeWeightsAccessor = conditionalEdgeWeights.accessor<float, 2>();
|
add_m_edges(idx, S, conditionalEdgeWeights);
|
||||||
// auto conditionalEdgeWeightsSorted = conditionalEdgeWeightsAccessor[idx].sort();
|
|
||||||
// auto conditionalEdgeWeightsSortedAccessor = conditionalEdgeWeightsSorted.accessor<float, 1>();
|
|
||||||
// for (auto i = 0; i < k; ++i) {
|
|
||||||
// auto index = conditionalEdgeWeightsSortedAccessor[i].item<int>();
|
|
||||||
// model.addEdge(features[idx], features[index]);
|
|
||||||
// }
|
|
||||||
// 5.5. Add Xmax to S.
|
// 5.5. Add Xmax to S.
|
||||||
S.push_back(idx);
|
S.push_back(idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
void KDB::add_m_edges(int idx, vector<int>& S, Tensor& weights)
|
||||||
|
{
|
||||||
|
auto n_edges = min(k, static_cast<int>(S.size()));
|
||||||
|
auto cond_w = clone(weights);
|
||||||
|
bool exit_cond = k == 0;
|
||||||
|
int num = 0;
|
||||||
|
while (!exit_cond) {
|
||||||
|
auto max_minfo = argmax(cond_w.index({ "...", idx })).item<int>();
|
||||||
|
auto belongs = find(S.begin(), S.end(), max_minfo) != S.end();
|
||||||
|
if (belongs && cond_w.index({ idx, max_minfo }).item<float>() > theta) {
|
||||||
|
try {
|
||||||
|
model.addEdge(features[idx], features[max_minfo]);
|
||||||
|
num++;
|
||||||
|
}
|
||||||
|
catch (const invalid_argument& e) {
|
||||||
|
// Loops are not allowed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cond_w.index_put_({ "...", max_minfo }, -1);
|
||||||
|
auto candidates = cond_w.gt(theta);
|
||||||
|
exit_cond = num == n_edges || candidates.size(0) == 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
@ -7,10 +7,12 @@ namespace bayesnet {
|
|||||||
class KDB : public BaseClassifier {
|
class KDB : public BaseClassifier {
|
||||||
private:
|
private:
|
||||||
int k;
|
int k;
|
||||||
|
float theta;
|
||||||
|
void add_m_edges(int idx, vector<int>& S, Tensor& weights);
|
||||||
protected:
|
protected:
|
||||||
void train();
|
void train();
|
||||||
public:
|
public:
|
||||||
KDB(int k);
|
KDB(int k, float theta);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
Loading…
Reference in New Issue
Block a user