diff --git a/bayesclass/clfs.py b/bayesclass/clfs.py index e0eb4cc..5c0edd4 100644 --- a/bayesclass/clfs.py +++ b/bayesclass/clfs.py @@ -260,6 +260,25 @@ class KDB(BayesBase): expected_args = ["class_name", "features", "state_names"] return self._check_params_fit(X, y, expected_args, kwargs) + def _add_m_edges(self, dag, idx, S_nodes, conditional_weights): + n_edges = min(self.k, len(S_nodes)) + cond_w = conditional_weights.copy() + exit_cond = self.k == 0 + num = 0 + while not exit_cond: + max_minfo = np.argmax(cond_w[idx, :]) + if max_minfo in S_nodes and cond_w[idx, max_minfo] > self.theta: + try: + dag.add_edge( + self.features_[max_minfo], self.features_[idx] + ) + num += 1 + except ValueError: + # Loops are not allowed + pass + cond_w[idx, max_minfo] = -1 + exit_cond = num == n_edges or np.all(cond_w[idx, :] <= 0) + def _build(self): """ 1. For each feature Xi, compute mutual information, I(X;;C), where C is the class. @@ -275,28 +294,6 @@ class KDB(BayesBase): Compute the conditional probabilility infered by the structure of BN by using counts from DB, and output BN. """ - def add_m_edges(dag, idx, S_nodes, conditional_weights): - n_edges = min(self.k, len(S_nodes)) - cond_w = conditional_weights.copy() - exit_cond = self.k == 0 - num = 0 - while not exit_cond: - max_minfo = np.argmax(cond_w[idx, :]) - if ( - max_minfo in S_nodes - and cond_w[idx, max_minfo] > self.theta - ): - try: - dag.add_edge( - self.features_[max_minfo], self.features_[idx] - ) - num += 1 - except ValueError: - # Loops are not allowed - pass - cond_w[idx, max_minfo] = -1 - exit_cond = num == n_edges or np.all(cond_w[idx, :] <= 0) - # 1. get the mutual information between each feature and the class mutual = mutual_info_classif(self.X_, self.y_, discrete_features=True) # 2. symmetric matrix where each element represents I(X, Y| class_node) @@ -318,7 +315,7 @@ class KDB(BayesBase): # 5.3 dag.add_edge(self.class_name_, feature) # 5.4 - add_m_edges(dag, idx, S_nodes, conditional_weights) + self._add_m_edges(dag, idx, S_nodes, conditional_weights) # 5.5 S_nodes.append(idx) self.dag_ = dag diff --git a/bayesclass/tests/test_KDB.py b/bayesclass/tests/test_KDB.py index d776519..10f93d4 100644 --- a/bayesclass/tests/test_KDB.py +++ b/bayesclass/tests/test_KDB.py @@ -4,6 +4,7 @@ from sklearn.datasets import load_iris from sklearn.preprocessing import KBinsDiscretizer from matplotlib.testing.decorators import image_comparison from matplotlib.testing.conftest import mpl_test_settings +from pgmpy.models import BayesianNetwork from bayesclass.clfs import KDB @@ -94,3 +95,19 @@ def test_KDB_error_size_predict(data, clf): with pytest.raises(ValueError): X_diff_size = np.ones((10, X.shape[1] + 1)) clf.predict(X_diff_size) + + +def test_KDB_dont_do_cycles(): + clf = KDB(k=4) + dag = BayesianNetwork() + clf.features_ = ["feature_0", "feature_1", "feature_2", "feature_3"] + nodes = list(range(4)) + weights = np.ones((4, 4)) + for idx in range(1, 4): + dag.add_edge(clf.features_[0], clf.features_[idx]) + dag.add_edge(clf.features_[1], clf.features_[2]) + dag.add_edge(clf.features_[1], clf.features_[3]) + dag.add_edge(clf.features_[2], clf.features_[3]) + for idx in range(4): + clf._add_m_edges(dag, idx, nodes, weights) + assert len(dag.edges()) == 6