From 39fbdf73a7f2365b86a57812e84ea3d3a5ae472d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana?= Date: Sat, 29 May 2021 01:24:37 +0200 Subject: [PATCH] Add max_features to MFS to help STree integration --- mfs/Selection.py | 27 ++++++++++---- mfs/tests/MFS_test.py | 87 ++++++++++++++++++++++++++++--------------- 2 files changed, 75 insertions(+), 39 deletions(-) diff --git a/mfs/Selection.py b/mfs/Selection.py index d971131..574f8fa 100755 --- a/mfs/Selection.py +++ b/mfs/Selection.py @@ -109,10 +109,16 @@ class MFS: Correlated Feature Selection as in "Correlation-based Feature Selection for Machine Learning" by Mark A. Hall + + Parameters + ---------- + max_features: int + The maximum number of features to return """ - def __init__(self): + def __init__(self, max_features): self._initialize() + self._max_features = max_features def _initialize(self): """Initialize the attributes so support multiple calls using same @@ -180,8 +186,8 @@ class MFS: """ # lgtm has already recognized that this is a false positive rcf = self._su_labels[ - features - ].sum() # lgtm [py/hash-unhashable-value] + features # lgtm [py/hash-unhashable-value] + ].sum() rff = 0.0 k = len(features) for pair in list(combinations(features, 2)): @@ -229,7 +235,10 @@ class MFS: candidates.append(feature_order[id_selected]) self._scores.append(merit) del feature_order[id_selected] - if len(feature_order) == 0: + if ( + len(feature_order) == 0 + or len(candidates) == self._max_features + ): # Force leaving the loop continue_condition = False if len(self._scores) >= 5: @@ -253,7 +262,7 @@ class MFS: self._result = candidates return self - def fcbs(self, X, y, threshold): + def fcbf(self, X, y, threshold): """Fast Correlation-Based Filter Parameters @@ -273,10 +282,10 @@ class MFS: Raises ------ ValueError - if the threshold is less than a selected value of 1e-4 + if the threshold is less than a selected value of 1e-7 """ - if threshold < 1e-4: - raise ValueError("Threshold cannot be less than 1e-4") + if threshold < 1e-7: + raise ValueError("Threshold cannot be less than 1e-7") self._initialize() self.X_ = X self.y_ = y @@ -301,6 +310,8 @@ class MFS: s_list[index_q] = 0.0 self._result.append(index_p) self._scores.append(s_list[index_p]) + if len(self._result) == self._max_features: + break return self def get_results(self): diff --git a/mfs/tests/MFS_test.py b/mfs/tests/MFS_test.py index da36ca2..db9aebf 100755 --- a/mfs/tests/MFS_test.py +++ b/mfs/tests/MFS_test.py @@ -21,8 +21,8 @@ class MFS_test(unittest.TestCase): self.assertAlmostEqual(a, b, tol) def test_initialize(self): - mfs = MFS() - mfs.fcbs(self.X_w, self.y_w, 0.05) + mfs = MFS(max_features=100) + mfs.fcbf(self.X_w, self.y_w, 0.05) mfs._initialize() self.assertIsNone(mfs.get_results()) self.assertListEqual([], mfs.get_scores()) @@ -30,7 +30,7 @@ class MFS_test(unittest.TestCase): self.assertIsNone(mfs._su_labels) def test_csf_wine(self): - mfs = MFS() + mfs = MFS(max_features=100) expected = [6, 12, 9, 4, 10, 0] self.assertListAlmostEqual( expected, mfs.cfs(self.X_w, self.y_w).get_results() @@ -45,8 +45,21 @@ class MFS_test(unittest.TestCase): ] self.assertListAlmostEqual(expected, mfs.get_scores()) + def test_csf_max_features(self): + mfs = MFS(max_features=3) + expected = [6, 12, 9] + self.assertListAlmostEqual( + expected, mfs.cfs(self.X_w, self.y_w).get_results() + ) + expected = [ + 0.5218299405215557, + 0.602513857132804, + 0.4877384978817362, + ] + self.assertListAlmostEqual(expected, mfs.get_scores()) + def test_csf_iris(self): - mfs = MFS() + mfs = MFS(max_features=100) expected = [3, 2, 0, 1] computed = mfs.cfs(self.X_i, self.y_i).get_results() self.assertListAlmostEqual(expected, computed) @@ -58,9 +71,9 @@ class MFS_test(unittest.TestCase): ] self.assertListAlmostEqual(expected, mfs.get_scores()) - def test_fcbs_wine(self): - mfs = MFS() - computed = mfs.fcbs(self.X_w, self.y_w, threshold=0.05).get_results() + def test_fcbf_wine(self): + mfs = MFS(max_features=100) + computed = mfs.fcbf(self.X_w, self.y_w, threshold=0.05).get_results() expected = [6, 9, 12, 0, 11, 4] self.assertListAlmostEqual(expected, computed) expected = [ @@ -73,30 +86,42 @@ class MFS_test(unittest.TestCase): ] self.assertListAlmostEqual(expected, mfs.get_scores()) - def test_fcbs_iris(self): - mfs = MFS() - computed = mfs.fcbs(self.X_i, self.y_i, threshold=0.05).get_results() - expected = [3, 2] - self.assertListAlmostEqual(expected, computed) - expected = [0.870521418179061, 0.810724587460511] - self.assertListAlmostEqual(expected, mfs.get_scores()) - - def test_compute_su_labels(self): - mfs = MFS() - mfs.fcbs(self.X_i, self.y_i, threshold=0.05) - expected = [0.0, 0.0, 0.810724587460511, 0.870521418179061] - self.assertListAlmostEqual(expected, mfs._compute_su_labels().tolist()) - mfs._su_labels = [1, 2, 3, 4] - self.assertListAlmostEqual([1, 2, 3, 4], mfs._compute_su_labels()) - - def test_invalid_threshold(self): - mfs = MFS() - with self.assertRaises(ValueError): - mfs.fcbs(self.X_i, self.y_i, threshold=1e-5) - - def test_fcbs_exit_threshold(self): - mfs = MFS() - computed = mfs.fcbs(self.X_w, self.y_w, threshold=0.4).get_results() + def test_fcbf_max_features(self): + mfs = MFS(max_features=3) + computed = mfs.fcbf(self.X_w, self.y_w, threshold=0.05).get_results() + expected = [6, 9, 12] + self.assertListAlmostEqual(expected, computed) + expected = [ + 0.5218299405215557, + 0.46224298637417455, + 0.44518278979085646, + ] + self.assertListAlmostEqual(expected, mfs.get_scores()) + + def test_fcbf_iris(self): + mfs = MFS(max_features=100) + computed = mfs.fcbf(self.X_i, self.y_i, threshold=0.05).get_results() + expected = [3, 2] + self.assertListAlmostEqual(expected, computed) + expected = [0.870521418179061, 0.810724587460511] + self.assertListAlmostEqual(expected, mfs.get_scores()) + + def test_compute_su_labels(self): + mfs = MFS(max_features=100) + mfs.fcbf(self.X_i, self.y_i, threshold=0.05) + expected = [0.0, 0.0, 0.810724587460511, 0.870521418179061] + self.assertListAlmostEqual(expected, mfs._compute_su_labels().tolist()) + mfs._su_labels = [1, 2, 3, 4] + self.assertListAlmostEqual([1, 2, 3, 4], mfs._compute_su_labels()) + + def test_invalid_threshold(self): + mfs = MFS(max_features=100) + with self.assertRaises(ValueError): + mfs.fcbf(self.X_i, self.y_i, threshold=1e-15) + + def test_fcbf_exit_threshold(self): + mfs = MFS(max_features=100) + computed = mfs.fcbf(self.X_w, self.y_w, threshold=0.4).get_results() expected = [6, 9, 12] self.assertListAlmostEqual(expected, computed) expected = [