(#46) Implement true random feature selection (#48)

* (#46) Implement true random feature selection
This commit is contained in:
Ricardo Montañana Gómez
2021-10-29 12:59:03 +02:00
committed by GitHub
parent 36b08b1bcf
commit bf678df159
2 changed files with 41 additions and 1 deletions

View File

@@ -273,6 +273,7 @@ class Splitter:
if feature_select not in [
"random",
"trandom",
"best",
"mutual",
"cfs",
@@ -280,7 +281,8 @@ class Splitter:
"iwss",
]:
raise ValueError(
"splitter must be in {random, best, mutual, cfs, fcbf, iwss} "
"splitter must be in {random, trandom, best, mutual, cfs, "
"fcbf, iwss} "
f"got ({feature_select})"
)
self.criterion_function = getattr(self, f"_{self._criterion}")
@@ -312,6 +314,31 @@ class Splitter:
features_sets = self._generate_spaces(n_features, max_features)
return self._select_best_set(dataset, labels, features_sets)
@staticmethod
def _fs_trandom(
dataset: np.array, labels: np.array, max_features: int
) -> tuple:
"""Return the a random feature set combination
Parameters
----------
dataset : np.array
array of samples
labels : np.array
labels of the dataset
max_features : int
number of features of the subspace
(< number of features in dataset)
Returns
-------
tuple
indices of the features selected
"""
# Random feature reduction
n_features = dataset.shape[1]
return tuple(sorted(random.sample(range(n_features), max_features)))
@staticmethod
def _fs_best(
dataset: np.array, labels: np.array, max_features: int

View File

@@ -297,3 +297,16 @@ class Splitter_test(unittest.TestCase):
Xs, computed = tcl.get_subspace(X, y, rs)
self.assertListEqual(expected, list(computed))
self.assertListEqual(X[:, expected].tolist(), Xs.tolist())
def test_get_trandom_subspaces(self):
results = [
(4, [3, 7, 9, 12]),
(6, [0, 1, 2, 8, 15, 18]),
(7, [1, 2, 4, 8, 10, 12, 13]),
]
for rs, expected in results:
X, y = load_dataset(n_features=20, n_informative=7)
tcl = self.build(feature_select="trandom", random_state=rs)
Xs, computed = tcl.get_subspace(X, y, rs)
self.assertListEqual(expected, list(computed))
self.assertListEqual(X[:, expected].tolist(), Xs.tolist())