Implement splitter type mutual info

This commit is contained in:
2021-05-01 23:38:34 +02:00
parent 28c7558f01
commit 5cef0f4875
2 changed files with 22 additions and 10 deletions

View File

@@ -11,7 +11,7 @@ from typing import Optional
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.svm import SVC, LinearSVC
from sklearn.feature_selection import SelectKBest
from sklearn.feature_selection import SelectKBest, mutual_info_classif
from sklearn.preprocessing import StandardScaler
from sklearn.utils.multiclass import check_classification_targets
from sklearn.exceptions import ConvergenceWarning
@@ -205,9 +205,9 @@ class Splitter:
f"criteria has to be max_samples or impurity; got ({criteria})"
)
if feature_select not in ["random", "best"]:
if feature_select not in ["random", "best", "mutual"]:
raise ValueError(
"splitter must be either random or best, got "
"splitter must be in {random, best, mutual} got "
f"({feature_select})"
)
self.criterion_function = getattr(self, f"_{self._criterion}")
@@ -381,11 +381,19 @@ class Splitter:
dataset.shape[1], max_features
)
return self._select_best_set(dataset, labels, features_sets)
# Take KBest features
return (
SelectKBest(k=max_features)
.fit(dataset, labels)
.get_support(indices=True)
if self._feature_select == "best":
# Take KBest features
return (
SelectKBest(k=max_features)
.fit(dataset, labels)
.get_support(indices=True)
)
# return best features with mutual info with the label
feature_list = mutual_info_classif(dataset, labels)
return tuple(
sorted(
range(len(feature_list)), key=lambda sub: feature_list[sub]
)[-max_features:]
)
def get_subspace(