Add local discretization tests

This commit is contained in:
2023-04-08 11:44:25 +02:00
parent 9843f5f8db
commit 74cd8a6aa2
3 changed files with 46 additions and 0 deletions

View File

@@ -91,6 +91,28 @@ def test_AODENew_classifier(data, clf):
assert sum(y == y_pred) == 147
def test_AODENew_local_discretization(clf, data):
expected_data = [
[-1, [0, -1], [0, -1], [0, -1]],
[[1, -1], -1, [1, -1], [1, -1]],
[[2, -1], [2, -1], -1, [2, -1]],
[[3, -1], [3, -1], [3, -1], -1],
]
clf.fit(*data)
for idx, estimator in enumerate(clf.estimators_):
expected = expected_data[idx]
for feature in range(4):
computed = estimator.discretizer_.target_[feature]
if type(computed) == list:
for j, k in zip(expected[feature], computed):
assert j == k
else:
assert (
expected[feature]
== estimator.discretizer_.target_[feature]
)
def test_AODENew_wrong_num_features(data, clf):
with pytest.raises(
ValueError,

View File

@@ -72,6 +72,21 @@ def test_KDBNew_classifier(data, clf):
assert sum(y == y_pred) == 148
def test_KDBNew_local_discretization(clf, data):
expected = [[1, -1], -1, [0, 1, 3, -1], [1, 0, -1]]
clf.fit(*data)
for feature in range(4):
computed = clf.estimator_.discretizer_.target_[feature]
if type(computed) == list:
for j, k in zip(expected[feature], computed):
assert j == k
else:
assert (
expected[feature]
== clf.estimator_.discretizer_.target_[feature]
)
@image_comparison(
baseline_images=["line_dashes_KDBNew"],
remove_text=True,

View File

@@ -63,6 +63,15 @@ def test_TANNew_random_head(clf, data):
assert clf.head_ == 3
def test_TANNew_local_discretization(clf, data):
expected = [-1, [0, -1], [0, -1], [1, -1]]
clf.fit(*data)
for feature in range(4):
assert (
expected[feature] == clf.estimator_.discretizer_.target_[feature]
)
def test_TANNew_classifier(data, clf):
clf.fit(*data)
attribs = [