Update tests and module mdlp version

This commit is contained in:
2023-04-11 19:33:57 +02:00
parent 0768d68a36
commit d04cb389c0
2 changed files with 20 additions and 10 deletions

View File

@@ -136,22 +136,32 @@ class FImdlpTest(unittest.TestCase):
self.assertListEqual(expected, computed) self.assertListEqual(expected, computed)
def test_join_fit(self): def test_join_fit(self):
y = np.array([b"f0", b"f0", b"f2", b"f3", b"f4"]) y = np.array([b"f0", b"f0", b"f2", b"f3", b"f3", b"f4", b"f4"])
x = np.array( x = np.array(
[ [
[0, 1, 2, 3, 4], [0, 1, 2, 3, 4, 5],
[0, 1, 2, 3, 4], [0, 2, 2, 3, 4, 5],
[1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 5],
[2, 3, 4, 5, 6], [2, 3, 4, 5, 6, 6],
[3, 4, 5, 6, 7], [3, 4, 5, 6, 7, 7],
[1, 2, 2, 3, 5, 7],
[1, 3, 4, 4, 4, 7],
] ]
) )
expected = [0, 0, 1, 2, 2] expected = [0, 1, 1, 2, 2, 1, 2]
clf = FImdlp() clf = FImdlp()
clf.fit(x, factorize(y)) clf.fit(x, factorize(y))
computed = clf.join_fit([0, 2], 1, x) computed = clf.join_fit([0, 2, 3, 4], 1, x)
self.assertListEqual(computed.tolist(), expected) self.assertListEqual(computed.tolist(), expected)
expected_y = [b"002", b"002", b"113", b"224", b"335"] expected_y = [
b"00234",
b"00234",
b"11345",
b"22456",
b"23567",
b"31235",
b"31444",
]
self.assertListEqual(expected_y, clf.y_join_) self.assertListEqual(expected_y, clf.y_join_)
def test_join_fit_error(self): def test_join_fit_error(self):