Update head hyperparam to use highest weight

This commit is contained in:
2022-11-07 00:51:19 +01:00
parent 8c03fc6b67
commit 02110f7608
4 changed files with 630 additions and 279 deletions

View File

@@ -16,12 +16,26 @@ def data():
return enc.fit_transform(X), y
def test_TAN_classifier(data):
def test_TAN_constructor():
clf = TAN()
# Test default values of hyperparameters
assert not clf.simple_init
assert not clf.show_progress
assert clf.random_state is None
clf = TAN(simple_init=True, show_progress=True, random_state=17)
assert clf.simple_init
assert clf.show_progress
assert clf.random_state == 17
def test_TAN_random_head(data):
clf = TAN(random_state=17)
clf.fit(*data, head="random")
assert clf.head_ == 3
def test_TAN_classifier(data):
clf = TAN()
clf.fit(*data)
attribs = ["classes_", "X_", "y_", "head_", "features_", "class_name_"]