Implement splitter type mutual info

This commit is contained in:
2021-05-01 23:38:34 +02:00
parent 28c7558f01
commit 5cef0f4875
2 changed files with 22 additions and 10 deletions

View File

@@ -11,7 +11,7 @@ from typing import Optional
 import numpy as np
 from sklearn.base import BaseEstimator, ClassifierMixin
 from sklearn.svm import SVC, LinearSVC
-from sklearn.feature_selection import SelectKBest
+from sklearn.feature_selection import SelectKBest, mutual_info_classif
 from sklearn.preprocessing import StandardScaler
 from sklearn.utils.multiclass import check_classification_targets
 from sklearn.exceptions import ConvergenceWarning
@@ -205,9 +205,9 @@ class Splitter:
                 f"criteria has to be max_samples or impurity; got ({criteria})"
             )
-        if feature_select not in ["random", "best"]:
+        if feature_select not in ["random", "best", "mutual"]:
             raise ValueError(
-                "splitter must be either random or best, got "
+                "splitter must be in {random, best, mutual} got "
                 f"({feature_select})"
             )
         self.criterion_function = getattr(self, f"_{self._criterion}")
@@ -381,11 +381,19 @@ class Splitter:
                 dataset.shape[1], max_features
             )
             return self._select_best_set(dataset, labels, features_sets)
-        # Take KBest features
-        return (
-            SelectKBest(k=max_features)
-            .fit(dataset, labels)
-            .get_support(indices=True)
-        )
+        if self._feature_select == "best":
+            # Take KBest features
+            return (
+                SelectKBest(k=max_features)
+                .fit(dataset, labels)
+                .get_support(indices=True)
+            )
+        # return best features with mutual info with the label
+        feature_list = mutual_info_classif(dataset, labels)
+        return tuple(
+            sorted(
+                range(len(feature_list)), key=lambda sub: feature_list[sub]
+            )[-max_features:]
+        )

     def get_subspace(

View File

@@ -195,10 +195,14 @@ class Splitter_test(unittest.TestCase):
             [0, 3, 7, 12],  # random entropy impurity
             [1, 7, 9, 12],  # random gini max_samples
             [1, 5, 8, 12],  # random gini impurity
+            [6, 9, 11, 12],  # mutual entropy max_samples
+            [6, 9, 11, 12],  # mutual entropy impurity
+            [6, 9, 11, 12],  # mutual gini max_samples
+            [6, 9, 11, 12],  # mutual gini impurity
         ]
         X, y = load_wine(return_X_y=True)
         rn = 0
-        for feature_select in ["best", "random"]:
+        for feature_select in ["best", "random", "mutual"]:
             for criterion in ["entropy", "gini"]:
                 for criteria in [
                     "max_samples",
@@ -221,7 +225,7 @@ class Splitter_test(unittest.TestCase):
                     #         criteria,
                     #     )
                     # )
-                    self.assertListEqual(expected, list(computed))
+                    self.assertListEqual(expected, sorted(list(computed)))
                     self.assertListEqual(
                         X[:, computed].tolist(), dataset.tolist()
                     )