Implement iwss feature selection (#45) (#47)

This commit is contained in:
Ricardo Montañana Gómez
2021-10-29 11:49:46 +02:00
committed by GitHub
parent 36ff3da26d
commit 36b08b1bcf
4 changed files with 81 additions and 39 deletions

View File

@@ -271,10 +271,17 @@ class Splitter:
f"criteria has to be max_samples or impurity; got ({criteria})"
)
if feature_select not in ["random", "best", "mutual", "cfs", "fcbf"]:
if feature_select not in [
"random",
"best",
"mutual",
"cfs",
"fcbf",
"iwss",
]:
raise ValueError(
"splitter must be in {random, best, mutual, cfs, fcbf} got "
f"({feature_select})"
"splitter must be in {random, best, mutual, cfs, fcbf, iwss} "
f"got ({feature_select})"
)
self.criterion_function = getattr(self, f"_{self._criterion}")
self.decision_criteria = getattr(self, f"_{self._criteria}")
@@ -409,6 +416,31 @@ class Splitter:
mufs = MUFS(max_features=max_features, discrete=False)
return mufs.fcbf(dataset, labels, 5e-4).get_results()
@staticmethod
def _fs_iwss(
    dataset: np.array, labels: np.array, max_features: int
) -> tuple:
    """Select a subspace of features with the iwss method of MUFS,
    limited to max_features

    Parameters
    ----------
    dataset : np.array
        array of samples
    labels : np.array
        labels of the dataset
    max_features : int
        number of features of the subspace
        (< number of features in dataset)

    Returns
    -------
    tuple
        indices of the features selected
    """
    selector = MUFS(max_features=max_features, discrete=False)
    # NOTE(review): 0.25 is passed as the third positional argument of
    # iwss; presumably a selection threshold -- confirm against MUFS.
    fitted = selector.iwss(dataset, labels, 0.25)
    return fitted.get_results()
def partition_impurity(self, y: np.array) -> np.array:
    """Evaluate the configured criterion function on the given labels

    Parameters
    ----------
    y : np.array
        labels handed over unchanged to self.criterion_function

    Returns
    -------
    np.array
        whatever self.criterion_function returns for y
    """
    impurity = self.criterion_function(y)
    return impurity

View File

@@ -285,3 +285,15 @@ class Splitter_test(unittest.TestCase):
Xs, computed = tcl.get_subspace(X, y, rs)
self.assertListEqual(expected, list(computed))
self.assertListEqual(X[:, expected].tolist(), Xs.tolist())
def test_get_iwss_subspaces(self):
    """Check the feature indices and columns returned by get_subspace
    when the splitter is built with feature_select="iwss"
    """
    cases = (
        (4, [1, 5, 9, 12]),
        (6, [1, 5, 9, 12, 4, 15]),
    )
    for seed, wanted in cases:
        # The dataset is reloaded for every case, as each case builds
        # its own splitter with a different random_state.
        X, y = load_dataset(n_features=20, n_informative=7)
        splitter = self.build(feature_select="iwss", random_state=seed)
        subspace, indices = splitter.get_subspace(X, y, seed)
        self.assertListEqual(wanted, list(indices))
        wanted_columns = X[:, wanted].tolist()
        self.assertListEqual(wanted_columns, subspace.tolist())