Complete TAN with Maximum Spanning Tree
This commit is contained in:
@@ -19,16 +19,20 @@ namespace bayesnet {
|
||||
auto mi_value = metrics.mutualInformation(class_dataset, feature_dataset);
|
||||
mi.push_back({ i, mi_value });
|
||||
}
|
||||
sort(mi.begin(), mi.end());
|
||||
sort(mi.begin(), mi.end(), [](auto& left, auto& right) {return left.second < right.second;});
|
||||
auto root = mi[mi.size() - 1].first;
|
||||
// 2. Compute mutual information between each feature and the class
|
||||
auto weights = metrics.conditionalEdge();
|
||||
// 3. Compute the maximum spanning tree
|
||||
auto mst = metrics.maximumSpanningTree(root, weights);
|
||||
auto mst = metrics.maximumSpanningTree(features, weights, root);
|
||||
// 4. Add edges from the maximum spanning tree to the model
|
||||
for (auto i = 0; i < mst.size(); ++i) {
|
||||
auto [from, to] = mst[i];
|
||||
model.addEdge(features[from], features[to]);
|
||||
}
|
||||
// 5. Add edges from the class to all features
|
||||
for (auto feature : features) {
|
||||
model.addEdge(className, feature);
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user