diff --git a/bayesclass/tests/test_AODENew.py b/bayesclass/tests/test_AODENew.py index 7da4414..72a73cb 100644 --- a/bayesclass/tests/test_AODENew.py +++ b/bayesclass/tests/test_AODENew.py @@ -91,6 +91,28 @@ def test_AODENew_classifier(data, clf): assert sum(y == y_pred) == 147 +def test_AODENew_local_discretization(clf, data): + expected_data = [ + [-1, [0, -1], [0, -1], [0, -1]], + [[1, -1], -1, [1, -1], [1, -1]], + [[2, -1], [2, -1], -1, [2, -1]], + [[3, -1], [3, -1], [3, -1], -1], + ] + clf.fit(*data) + for idx, estimator in enumerate(clf.estimators_): + expected = expected_data[idx] + for feature in range(4): + computed = estimator.discretizer_.target_[feature] + if type(computed) == list: + for j, k in zip(expected[feature], computed): + assert j == k + else: + assert ( + expected[feature] + == estimator.discretizer_.target_[feature] + ) + + def test_AODENew_wrong_num_features(data, clf): with pytest.raises( ValueError, diff --git a/bayesclass/tests/test_KDBNew.py b/bayesclass/tests/test_KDBNew.py index e8948ae..b14d731 100644 --- a/bayesclass/tests/test_KDBNew.py +++ b/bayesclass/tests/test_KDBNew.py @@ -72,6 +72,21 @@ def test_KDBNew_classifier(data, clf): assert sum(y == y_pred) == 148 +def test_KDBNew_local_discretization(clf, data): + expected = [[1, -1], -1, [0, 1, 3, -1], [1, 0, -1]] + clf.fit(*data) + for feature in range(4): + computed = clf.estimator_.discretizer_.target_[feature] + if type(computed) == list: + for j, k in zip(expected[feature], computed): + assert j == k + else: + assert ( + expected[feature] + == clf.estimator_.discretizer_.target_[feature] + ) + + @image_comparison( baseline_images=["line_dashes_KDBNew"], remove_text=True, diff --git a/bayesclass/tests/test_TANNew.py b/bayesclass/tests/test_TANNew.py index 2208f26..406222e 100644 --- a/bayesclass/tests/test_TANNew.py +++ b/bayesclass/tests/test_TANNew.py @@ -63,6 +63,15 @@ def test_TANNew_random_head(clf, data): assert clf.head_ == 3 +def test_TANNew_local_discretization(clf, data): + expected = [-1, [0, -1], [0, -1], [1, -1]] + clf.fit(*data) + for feature in range(4): + assert ( + expected[feature] == clf.estimator_.discretizer_.target_[feature] + ) + + def test_TANNew_classifier(data, clf): clf.fit(*data) attribs = [