Mirror of https://github.com/Doctorado-ML/STree.git
Implement splitter type mutual info
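This change adds a third feature_select mode, "mutual", to the Splitter: besides sampling random feature subsets ("random") and taking the SelectKBest features ("best"), a subspace can now be built from the features that share the most mutual information with the class labels. A hedged usage sketch follows; it assumes the Stree estimator forwards its splitter argument to the Splitter's feature_select, as the error message wording ("splitter must be ...") suggests, and that the estimator accepts max_features and random_state as shown.

from sklearn.datasets import load_wine
from stree import Stree

X, y = load_wine(return_X_y=True)

# Hypothetical usage of the new mode: splitter="mutual" would select the
# mutual-information based subspace search added in this commit.
clf = Stree(splitter="mutual", max_features=5, random_state=0)
clf.fit(X, y)
print(clf.score(X, y))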
@@ -11,7 +11,7 @@ from typing import Optional
 import numpy as np
 from sklearn.base import BaseEstimator, ClassifierMixin
 from sklearn.svm import SVC, LinearSVC
-from sklearn.feature_selection import SelectKBest
+from sklearn.feature_selection import SelectKBest, mutual_info_classif
 from sklearn.preprocessing import StandardScaler
 from sklearn.utils.multiclass import check_classification_targets
 from sklearn.exceptions import ConvergenceWarning
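For context, mutual_info_classif (newly imported above) scores each column of the dataset against the labels and returns one non-negative value per feature; the last hunk below ranks features by these scores. A small, self-contained illustration, independent of the patch itself:

import numpy as np
from sklearn.datasets import load_iris
from sklearn.feature_selection import mutual_info_classif

X, y = load_iris(return_X_y=True)

# One score per feature; higher means more information shared with the label
scores = mutual_info_classif(X, y, random_state=0)
print(scores.shape)          # (4,)
print(np.round(scores, 3))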
@@ -205,9 +205,9 @@ class Splitter:
                 f"criteria has to be max_samples or impurity; got ({criteria})"
             )
 
-        if feature_select not in ["random", "best"]:
+        if feature_select not in ["random", "best", "mutual"]:
             raise ValueError(
-                "splitter must be either random or best, got "
+                "splitter must be in {random, best, mutual} got "
                 f"({feature_select})"
             )
         self.criterion_function = getattr(self, f"_{self._criterion}")
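The accepted values for feature_select therefore grow from {random, best} to {random, best, mutual}, and anything else still fails fast. A standalone mirror of the updated check, for illustration only (the real validation lives in Splitter.__init__ as shown above):

def check_feature_select(feature_select: str) -> str:
    # Same condition and message as the patched Splitter.__init__
    if feature_select not in ["random", "best", "mutual"]:
        raise ValueError(
            "splitter must be in {random, best, mutual} got "
            f"({feature_select})"
        )
    return feature_select

check_feature_select("mutual")       # accepted after this commit
try:
    check_feature_select("kbest")    # unsupported values still raise
except ValueError as error:
    print(error)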
@@ -381,11 +381,19 @@ class Splitter:
                 dataset.shape[1], max_features
             )
             return self._select_best_set(dataset, labels, features_sets)
-        # Take KBest features
-        return (
-            SelectKBest(k=max_features)
-            .fit(dataset, labels)
-            .get_support(indices=True)
+        if self._feature_select == "best":
+            # Take KBest features
+            return (
+                SelectKBest(k=max_features)
+                .fit(dataset, labels)
+                .get_support(indices=True)
+            )
+        # return best features with mutual info with the label
+        feature_list = mutual_info_classif(dataset, labels)
+        return tuple(
+            sorted(
+                range(len(feature_list)), key=lambda sub: feature_list[sub]
+            )[-max_features:]
         )
 
     def get_subspace(
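The new branch keeps the max_features features with the largest mutual-information scores and returns their column indices as a tuple, in ascending score order. A self-contained sketch of that ranking step on a public dataset (illustrative only, not part of the library):

import numpy as np
from sklearn.datasets import load_wine
from sklearn.feature_selection import mutual_info_classif

X, y = load_wine(return_X_y=True)
max_features = 5

feature_list = mutual_info_classif(X, y, random_state=0)

# The patch's ranking: sort feature indices by score, keep the top max_features
selected = tuple(
    sorted(range(len(feature_list)), key=lambda sub: feature_list[sub])[
        -max_features:
    ]
)

# Equivalent vectorized form (index order may differ when scores tie)
selected_np = tuple(np.argsort(feature_list)[-max_features:])
assert set(selected) == set(selected_np)
print(selected)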