Add example of usage

This commit is contained in:
2022-11-12 13:16:11 +01:00
parent 180365d727
commit 8ca3646f8c
12 changed files with 88 additions and 1293 deletions

View File

@@ -19,7 +19,7 @@ def data():
def test_TAN_constructor():
clf = TAN()
# Test default values of hyperparameters
assert not clf.simple_init
assert clf.simple_init
assert not clf.show_progress
assert clf.random_state is None
clf = TAN(simple_init=True, show_progress=True, random_state=17)
@@ -34,6 +34,14 @@ def test_TAN_random_head(data):
assert clf.head_ == 3
def test_TAN_dag_initializer(data):
clf_not_simple = TAN(simple_init=False)
clf_simple = TAN(simple_init=True)
clf_not_simple.fit(*data, head=0)
clf_simple.fit(*data, head=0)
assert clf_simple.dag_.edges == clf_not_simple.dag_.edges
def test_TAN_classifier(data):
clf = TAN()