Add thresold to KDB and fix error

This commit is contained in:
2022-11-16 13:15:57 +01:00
parent cad828a21b
commit 2973bc9519
3 changed files with 22 additions and 4 deletions

View File

@@ -241,8 +241,9 @@ class TAN(BayesBase):
class KDB(BayesBase):
def __init__(self, k, show_progress=False, random_state=None):
def __init__(self, k, theta=0.03, show_progress=False, random_state=None):
self.k = k
self.theta = theta
super().__init__(
show_progress=show_progress, random_state=random_state
)
@@ -289,11 +290,14 @@ class KDB(BayesBase):
def add_m_edges(dag, idx, S_nodes, conditional_weights):
n_edges = min(self.k, len(S_nodes))
cond_w = conditional_weights.copy()
exit_cond = False
exit_cond = self.k == 0
num = 0
while not exit_cond:
max_minfo = np.argmax(cond_w[idx, :])
if max_minfo in S_nodes:
if (
max_minfo in S_nodes
and cond_w[idx, max_minfo] > self.theta
):
try:
dag.add_edge(
self.features_[max_minfo], self.features_[idx]

View File

@@ -26,6 +26,7 @@ def test_KDB_default_hyperparameters(data, clf):
# Test default values of hyperparameters
assert not clf.show_progress
assert clf.random_state is None
assert clf.theta == 0.03
clf = KDB(show_progress=True, random_state=17, k=3)
assert clf.show_progress
assert clf.random_state == 17