Implement splitter type mutual info

2021-05-01 23:38:34 +02:00
parent 28c7558f01
commit 5cef0f4875
2 changed files with 22 additions and 10 deletions


@@ -11,7 +11,7 @@ from typing import Optional
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.svm import SVC, LinearSVC
from sklearn.feature_selection import SelectKBest
from sklearn.feature_selection import SelectKBest, mutual_info_classif
from sklearn.preprocessing import StandardScaler
from sklearn.utils.multiclass import check_classification_targets
from sklearn.exceptions import ConvergenceWarning
@@ -205,9 +205,9 @@ class Splitter:
f"criteria has to be max_samples or impurity; got ({criteria})"
)
if feature_select not in ["random", "best"]:
if feature_select not in ["random", "best", "mutual"]:
raise ValueError(
"splitter must be either random or best, got "
"splitter must be in {random, best, mutual} got "
f"({feature_select})"
)
self.criterion_function = getattr(self, f"_{self._criterion}")
@@ -381,11 +381,19 @@ class Splitter:
dataset.shape[1], max_features
)
return self._select_best_set(dataset, labels, features_sets)
# Take KBest features
return (
SelectKBest(k=max_features)
.fit(dataset, labels)
.get_support(indices=True)
if self._feature_select == "best":
# Take KBest features
return (
SelectKBest(k=max_features)
.fit(dataset, labels)
.get_support(indices=True)
)
# Return the best features ranked by mutual information with the label
feature_list = mutual_info_classif(dataset, labels)
return tuple(
sorted(
range(len(feature_list)), key=lambda sub: feature_list[sub]
)[-max_features:]
)
def get_subspace(
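
For reference, the new "mutual" branch scores every feature by its estimated mutual information with the label and keeps the max_features highest-scoring indices. A minimal standalone sketch of the same idea, using the wine dataset from the test below and an illustrative k (the variable names and the use of np.argsort are mine, not taken from the commit):

import numpy as np
from sklearn.datasets import load_wine
from sklearn.feature_selection import mutual_info_classif

X, y = load_wine(return_X_y=True)
k = 4  # illustrative stand-in for max_features

# Score every feature by its estimated mutual information with the label.
scores = mutual_info_classif(X, y)

# Keep the indices of the k highest-scoring features; np.argsort(scores)[-k:]
# is equivalent (up to tie-breaking) to the commit's
# sorted(range(len(scores)), key=lambda sub: scores[sub])[-k:].
selected = tuple(np.argsort(scores)[-k:])
print(selected)

Note that mutual_info_classif uses a nearest-neighbors estimator and is stochastic unless a random_state is passed, so the selected indices may vary slightly between runs.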


@@ -195,10 +195,14 @@ class Splitter_test(unittest.TestCase):
[0, 3, 7, 12], # random entropy impurity
[1, 7, 9, 12], # random gini max_samples
[1, 5, 8, 12], # random gini impurity
[6, 9, 11, 12], # mutual entropy max_samples
[6, 9, 11, 12], # mutual entropy impurity
[6, 9, 11, 12], # mutual gini max_samples
[6, 9, 11, 12], # mutual gini impurity
]
X, y = load_wine(return_X_y=True)
rn = 0
for feature_select in ["best", "random"]:
for feature_select in ["best", "random", "mutual"]:
for criterion in ["entropy", "gini"]:
for criteria in [
"max_samples",
@@ -221,7 +225,7 @@ class Splitter_test(unittest.TestCase):
# criteria,
# )
# )
self.assertListEqual(expected, list(computed))
self.assertListEqual(expected, sorted(list(computed)))
self.assertListEqual(
X[:, computed].tolist(), dataset.tolist()
)
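
The assertion now sorts the computed subspace before comparing: SelectKBest's get_support(indices=True) already returns indices in ascending order, but the mutual-information branch returns them ordered by score, so the raw tuple need not match expected lists written in ascending index order. A small illustration with made-up scores (values not taken from the commit):

scores = [0.10, 0.90, 0.50, 0.70]  # made-up mutual-information scores
k = 3

by_score = sorted(range(len(scores)), key=lambda i: scores[i])[-k:]
print(by_score)          # [2, 3, 1]  -> ordered by increasing score
print(sorted(by_score))  # [1, 2, 3]  -> ascending indices, as the expected lists are written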