diff --git a/README.md b/README.md index 32f5b81..9638174 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,15 @@ # bayesclass -Bayesian Estimators + +## Bayesian Estimators + +### TAN Tree augmented naive Bayes + +Friedman, N., Geiger, D. & Goldszmidt, M. Bayesian Network Classifiers. Machine Learning 29, 131–163 (1997). https://doi.org/10.1023/A:1007465528199 + +### KDB + +Mehran Sahami. 1996. Learning limited dependence Bayesian classifiers. In Proceedings of the Second International Conference on Knowledge Discovery and Data Mining (KDD'96). AAAI Press, 335–338. + +### AODE Averaged One-Dependence Estimators + +Webb, G., Boughton, J. & Wang, Z. Not So Naive Bayes: Aggregating One-Dependence Estimators. Mach Learn 58, 5–24 (2005). https://doi.org/10.1007/s10994-005-4258-6 diff --git a/bayesclass/bayesclass.py b/bayesclass/bayesclass.py index 284612c..2f4ec19 100644 --- a/bayesclass/bayesclass.py +++ b/bayesclass/bayesclass.py @@ -241,8 +241,9 @@ class TAN(BayesBase): class KDB(BayesBase): - def __init__(self, k, show_progress=False, random_state=None): + def __init__(self, k, theta=0.03, show_progress=False, random_state=None): self.k = k + self.theta = theta super().__init__( show_progress=show_progress, random_state=random_state ) @@ -289,11 +290,14 @@ class KDB(BayesBase): def add_m_edges(dag, idx, S_nodes, conditional_weights): n_edges = min(self.k, len(S_nodes)) cond_w = conditional_weights.copy() - exit_cond = False + exit_cond = self.k == 0 num = 0 while not exit_cond: max_minfo = np.argmax(cond_w[idx, :]) - if max_minfo in S_nodes: + if ( + max_minfo in S_nodes + and cond_w[idx, max_minfo] > self.theta + ): try: dag.add_edge( self.features_[max_minfo], self.features_[idx] diff --git a/bayesclass/tests/test_KDB.py b/bayesclass/tests/test_KDB.py index a42e9aa..e280dcd 100644 --- a/bayesclass/tests/test_KDB.py +++ b/bayesclass/tests/test_KDB.py @@ -26,6 +26,7 @@ def test_KDB_default_hyperparameters(data, clf): # Test default values of hyperparameters assert not clf.show_progress assert clf.random_state is None + assert clf.theta == 0.03 clf = KDB(show_progress=True, random_state=17, k=3) assert clf.show_progress assert clf.random_state == 17