Update tests and module mdlp version

This commit is contained in:
2023-04-11 19:33:57 +02:00
parent 0768d68a36
commit d04cb389c0
2 changed files with 20 additions and 10 deletions

View File

@@ -136,22 +136,32 @@ class FImdlpTest(unittest.TestCase):
self.assertListEqual(expected, computed)
def test_join_fit(self):
y = np.array([b"f0", b"f0", b"f2", b"f3", b"f4"])
y = np.array([b"f0", b"f0", b"f2", b"f3", b"f3", b"f4", b"f4"])
x = np.array(
[
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[1, 2, 3, 4, 5],
[2, 3, 4, 5, 6],
[3, 4, 5, 6, 7],
[0, 1, 2, 3, 4, 5],
[0, 2, 2, 3, 4, 5],
[1, 2, 3, 4, 5, 5],
[2, 3, 4, 5, 6, 6],
[3, 4, 5, 6, 7, 7],
[1, 2, 2, 3, 5, 7],
[1, 3, 4, 4, 4, 7],
]
)
expected = [0, 0, 1, 2, 2]
expected = [0, 1, 1, 2, 2, 1, 2]
clf = FImdlp()
clf.fit(x, factorize(y))
computed = clf.join_fit([0, 2], 1, x)
computed = clf.join_fit([0, 2, 3, 4], 1, x)
self.assertListEqual(computed.tolist(), expected)
expected_y = [b"002", b"002", b"113", b"224", b"335"]
expected_y = [
b"00234",
b"00234",
b"11345",
b"22456",
b"23567",
b"31235",
b"31444",
]
self.assertListEqual(expected_y, clf.y_join_)
def test_join_fit_error(self):