Add max_features to MFS to help STree integration

This commit is contained in:
2021-05-29 01:24:37 +02:00
parent ee5020e6d9
commit 39fbdf73a7
2 changed files with 75 additions and 39 deletions

View File

@@ -109,10 +109,16 @@ class MFS:
Correlated Feature Selection as in "Correlation-based Feature Selection for
Machine Learning" by Mark A. Hall
Parameters
----------
max_features: int
The maximum number of features to return
"""
def __init__(self):
def __init__(self, max_features):
self._initialize()
self._max_features = max_features
def _initialize(self):
"""Initialize the attributes so support multiple calls using same
@@ -180,8 +186,8 @@ class MFS:
"""
# lgtm has already recognized that this is a false positive
rcf = self._su_labels[
features
].sum() # lgtm [py/hash-unhashable-value]
features # lgtm [py/hash-unhashable-value]
].sum()
rff = 0.0
k = len(features)
for pair in list(combinations(features, 2)):
@@ -229,7 +235,10 @@ class MFS:
candidates.append(feature_order[id_selected])
self._scores.append(merit)
del feature_order[id_selected]
if len(feature_order) == 0:
if (
len(feature_order) == 0
or len(candidates) == self._max_features
):
# Force leaving the loop
continue_condition = False
if len(self._scores) >= 5:
@@ -253,7 +262,7 @@ class MFS:
self._result = candidates
return self
def fcbs(self, X, y, threshold):
def fcbf(self, X, y, threshold):
"""Fast Correlation-Based Filter
Parameters
@@ -273,10 +282,10 @@ class MFS:
Raises
------
ValueError
if the threshold is less than a selected value of 1e-4
if the threshold is less than a selected value of 1e-7
"""
if threshold < 1e-4:
raise ValueError("Threshold cannot be less than 1e-4")
if threshold < 1e-7:
raise ValueError("Threshold cannot be less than 1e-7")
self._initialize()
self.X_ = X
self.y_ = y
@@ -301,6 +310,8 @@ class MFS:
s_list[index_q] = 0.0
self._result.append(index_p)
self._scores.append(s_list[index_p])
if len(self._result) == self._max_features:
break
return self
def get_results(self):