diff --git a/samples/sample.py b/samples/sample.py index e7f5fca..7ab2a19 100644 --- a/samples/sample.py +++ b/samples/sample.py @@ -14,8 +14,11 @@ datasets = { } ap = argparse.ArgumentParser() -ap.add_argument("--proposal", action="store_true") -ap.add_argument("--original", dest="proposal", action="store_false") +ap.add_argument("--proposal", action="store_const", const=1) +ap.add_argument("--original", dest="proposal", action="store_const", const=0) +ap.add_argument( + "--alternative", dest="proposal", action="store_const", const=2 +) ap.add_argument("dataset", type=str, choices=datasets.keys()) args = ap.parse_args() relative = "" if os.path.isdir("src") else ".." diff --git a/src/fimdlp/tests/FImdlp_test.py b/src/fimdlp/tests/FImdlp_test.py index c6d01f0..df8da03 100644 --- a/src/fimdlp/tests/FImdlp_test.py +++ b/src/fimdlp/tests/FImdlp_test.py @@ -14,13 +14,13 @@ class FImdlpTest(unittest.TestCase): def test_init(self): clf = FImdlp() self.assertEqual(-1, clf.n_jobs) - self.assertFalse(clf.proposal) - clf = FImdlp(proposal=True, n_jobs=7) - self.assertTrue(clf.proposal) + self.assertEqual(0, clf.proposal) + clf = FImdlp(proposal=1, n_jobs=7) + self.assertEqual(1, clf.proposal) self.assertEqual(7, clf.n_jobs) def test_fit_proposal(self): - clf = FImdlp(proposal=True) + clf = FImdlp(proposal=1) clf.fit([[1, 2], [3, 4]], [1, 2]) self.assertEqual(clf.n_features_, 2) self.assertListEqual(clf.X_.tolist(), [[1, 2], [3, 4]]) @@ -49,7 +49,7 @@ class FImdlpTest(unittest.TestCase): self.assertListEqual([0, 2, 3], clf.features_) def test_fit_original(self): - clf = FImdlp(proposal=False) + clf = FImdlp(proposal=0) clf.fit([[1, 2], [3, 4]], [1, 2]) self.assertEqual(clf.n_features_, 2) self.assertListEqual(clf.X_.tolist(), [[1, 2], [3, 4]]) @@ -94,7 +94,7 @@ class FImdlpTest(unittest.TestCase): self.assertListEqual(res.tolist(), [[0, 2], [0, 4]]) def test_transform_original(self): - clf = FImdlp(proposal=False) + clf = FImdlp(proposal=0) clf.fit([[1, 2], [3, 4]], [1, 2]) self.assertEqual( clf.transform([[1, 2], [3, 4]]).tolist(), [[0, 0], [0, 0]] @@ -120,11 +120,11 @@ class FImdlpTest(unittest.TestCase): with self.assertRaises(ValueError): clf.transform([[1, 2, 3], [4, 5, 6]]) with self.assertRaises(sklearn.exceptions.NotFittedError): - clf = FImdlp(proposal=False) + clf = FImdlp(proposal=0) clf.transform([[1, 2], [3, 4]]) def test_transform_proposal(self): - clf = FImdlp(proposal=True) + clf = FImdlp(proposal=1) clf.fit([[1, 2], [3, 4]], [1, 2]) self.assertEqual( clf.transform([[1, 2], [3, 4]]).tolist(), [[0, 0], [0, 0]] @@ -150,5 +150,5 @@ class FImdlpTest(unittest.TestCase): with self.assertRaises(ValueError): clf.transform([[1, 2, 3], [4, 5, 6]]) with self.assertRaises(sklearn.exceptions.NotFittedError): - clf = FImdlp(proposal=True) + clf = FImdlp(proposal=1) clf.transform([[1, 2], [3, 4]])