mirror of
https://github.com/Doctorado-ML/bayesclass.git
synced 2025-08-15 23:55:57 +00:00
Add thresold to KDB and fix error
This commit is contained in:
15
README.md
15
README.md
@@ -1,2 +1,15 @@
|
|||||||
# bayesclass
|
# bayesclass
|
||||||
Bayesian Estimators
|
|
||||||
|
## Bayesian Estimators
|
||||||
|
|
||||||
|
### TAN Tree augmented naive Bayes
|
||||||
|
|
||||||
|
Friedman, N., Geiger, D. & Goldszmidt, M. Bayesian Network Classifiers. Machine Learning 29, 131–163 (1997). https://doi.org/10.1023/A:1007465528199
|
||||||
|
|
||||||
|
### KDB
|
||||||
|
|
||||||
|
Mehran Sahami. 1996. Learning limited dependence Bayesian classifiers. In Proceedings of the Second International Conference on Knowledge Discovery and Data Mining (KDD'96). AAAI Press, 335–338.
|
||||||
|
|
||||||
|
### AODE Averaged One-Dependence Estimators
|
||||||
|
|
||||||
|
Webb, G., Boughton, J. & Wang, Z. Not So Naive Bayes: Aggregating One-Dependence Estimators. Mach Learn 58, 5–24 (2005). https://doi.org/10.1007/s10994-005-4258-6
|
||||||
|
@@ -241,8 +241,9 @@ class TAN(BayesBase):
|
|||||||
|
|
||||||
|
|
||||||
class KDB(BayesBase):
|
class KDB(BayesBase):
|
||||||
def __init__(self, k, show_progress=False, random_state=None):
|
def __init__(self, k, theta=0.03, show_progress=False, random_state=None):
|
||||||
self.k = k
|
self.k = k
|
||||||
|
self.theta = theta
|
||||||
super().__init__(
|
super().__init__(
|
||||||
show_progress=show_progress, random_state=random_state
|
show_progress=show_progress, random_state=random_state
|
||||||
)
|
)
|
||||||
@@ -289,11 +290,14 @@ class KDB(BayesBase):
|
|||||||
def add_m_edges(dag, idx, S_nodes, conditional_weights):
|
def add_m_edges(dag, idx, S_nodes, conditional_weights):
|
||||||
n_edges = min(self.k, len(S_nodes))
|
n_edges = min(self.k, len(S_nodes))
|
||||||
cond_w = conditional_weights.copy()
|
cond_w = conditional_weights.copy()
|
||||||
exit_cond = False
|
exit_cond = self.k == 0
|
||||||
num = 0
|
num = 0
|
||||||
while not exit_cond:
|
while not exit_cond:
|
||||||
max_minfo = np.argmax(cond_w[idx, :])
|
max_minfo = np.argmax(cond_w[idx, :])
|
||||||
if max_minfo in S_nodes:
|
if (
|
||||||
|
max_minfo in S_nodes
|
||||||
|
and cond_w[idx, max_minfo] > self.theta
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
dag.add_edge(
|
dag.add_edge(
|
||||||
self.features_[max_minfo], self.features_[idx]
|
self.features_[max_minfo], self.features_[idx]
|
||||||
|
@@ -26,6 +26,7 @@ def test_KDB_default_hyperparameters(data, clf):
|
|||||||
# Test default values of hyperparameters
|
# Test default values of hyperparameters
|
||||||
assert not clf.show_progress
|
assert not clf.show_progress
|
||||||
assert clf.random_state is None
|
assert clf.random_state is None
|
||||||
|
assert clf.theta == 0.03
|
||||||
clf = KDB(show_progress=True, random_state=17, k=3)
|
clf = KDB(show_progress=True, random_state=17, k=3)
|
||||||
assert clf.show_progress
|
assert clf.show_progress
|
||||||
assert clf.random_state == 17
|
assert clf.random_state == 17
|
||||||
|
Reference in New Issue
Block a user