fix: 🐛 Fix Tests and sample mistake

This commit is contained in:
2022-12-15 12:18:10 +01:00
parent fe32ed4b2a
commit edd464311f
2 changed files with 14 additions and 11 deletions

View File

@@ -14,8 +14,11 @@ datasets = {
}
ap = argparse.ArgumentParser()
ap.add_argument("--proposal", action="store_true")
ap.add_argument("--original", dest="proposal", action="store_false")
ap.add_argument("--proposal", action="store_const", const=1)
ap.add_argument("--original", dest="proposal", action="store_const", const=0)
ap.add_argument(
"--alternative", dest="proposal", action="store_const", const=2
)
ap.add_argument("dataset", type=str, choices=datasets.keys())
args = ap.parse_args()
relative = "" if os.path.isdir("src") else ".."

View File

@@ -14,13 +14,13 @@ class FImdlpTest(unittest.TestCase):
def test_init(self):
clf = FImdlp()
self.assertEqual(-1, clf.n_jobs)
self.assertFalse(clf.proposal)
clf = FImdlp(proposal=True, n_jobs=7)
self.assertTrue(clf.proposal)
self.assertEqual(0, clf.proposal)
clf = FImdlp(proposal=1, n_jobs=7)
self.assertEqual(1, clf.proposal)
self.assertEqual(7, clf.n_jobs)
def test_fit_proposal(self):
clf = FImdlp(proposal=True)
clf = FImdlp(proposal=1)
clf.fit([[1, 2], [3, 4]], [1, 2])
self.assertEqual(clf.n_features_, 2)
self.assertListEqual(clf.X_.tolist(), [[1, 2], [3, 4]])
@@ -49,7 +49,7 @@ class FImdlpTest(unittest.TestCase):
self.assertListEqual([0, 2, 3], clf.features_)
def test_fit_original(self):
clf = FImdlp(proposal=False)
clf = FImdlp(proposal=0)
clf.fit([[1, 2], [3, 4]], [1, 2])
self.assertEqual(clf.n_features_, 2)
self.assertListEqual(clf.X_.tolist(), [[1, 2], [3, 4]])
@@ -94,7 +94,7 @@ class FImdlpTest(unittest.TestCase):
self.assertListEqual(res.tolist(), [[0, 2], [0, 4]])
def test_transform_original(self):
clf = FImdlp(proposal=False)
clf = FImdlp(proposal=0)
clf.fit([[1, 2], [3, 4]], [1, 2])
self.assertEqual(
clf.transform([[1, 2], [3, 4]]).tolist(), [[0, 0], [0, 0]]
@@ -120,11 +120,11 @@ class FImdlpTest(unittest.TestCase):
with self.assertRaises(ValueError):
clf.transform([[1, 2, 3], [4, 5, 6]])
with self.assertRaises(sklearn.exceptions.NotFittedError):
clf = FImdlp(proposal=False)
clf = FImdlp(proposal=0)
clf.transform([[1, 2], [3, 4]])
def test_transform_proposal(self):
clf = FImdlp(proposal=True)
clf = FImdlp(proposal=1)
clf.fit([[1, 2], [3, 4]], [1, 2])
self.assertEqual(
clf.transform([[1, 2], [3, 4]]).tolist(), [[0, 0], [0, 0]]
@@ -150,5 +150,5 @@ class FImdlpTest(unittest.TestCase):
with self.assertRaises(ValueError):
clf.transform([[1, 2, 3], [4, 5, 6]])
with self.assertRaises(sklearn.exceptions.NotFittedError):
clf = FImdlp(proposal=True)
clf = FImdlp(proposal=1)
clf.transform([[1, 2], [3, 4]])