Add version and example

This commit is contained in:
2022-11-13 14:24:06 +01:00
parent f9f91c54a7
commit 6070c2159a
6 changed files with 52 additions and 19 deletions

View File

@@ -7,6 +7,7 @@ from matplotlib.testing.conftest import mpl_test_settings
from bayesclass import TAN
from .._version import __version__
@pytest.fixture
@@ -16,7 +17,7 @@ def data():
return enc.fit_transform(X), y
def test_TAN_constructor():
def test_TAN_default_hyperparameters(data):
clf = TAN()
# Test default values of hyperparameters
assert clf.simple_init
@@ -26,6 +27,21 @@ def test_TAN_constructor():
assert clf.simple_init
assert clf.show_progress
assert clf.random_state == 17
clf.fit(*data)
assert clf.head_ == 0
assert clf.class_name_ == "class"
assert clf.features_ == [
"feature_0",
"feature_1",
"feature_2",
"feature_3",
]
def test_TAN_version():
"""Check TAN version."""
clf = TAN()
assert __version__ == clf.version()
def test_TAN_random_head(data):

View File

@@ -7,5 +7,4 @@ from bayesclass import TAN
@pytest.mark.parametrize("estimator", [TAN()])
def test_all_estimators(estimator):
# return check_estimator(estimator)
assert True
return check_estimator(estimator)