mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-15 15:36:00 +00:00
Implement splitter type mutual info
This commit is contained in:
@@ -11,7 +11,7 @@ from typing import Optional
|
||||
import numpy as np
|
||||
from sklearn.base import BaseEstimator, ClassifierMixin
|
||||
from sklearn.svm import SVC, LinearSVC
|
||||
from sklearn.feature_selection import SelectKBest
|
||||
from sklearn.feature_selection import SelectKBest, mutual_info_classif
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.utils.multiclass import check_classification_targets
|
||||
from sklearn.exceptions import ConvergenceWarning
|
||||
@@ -205,9 +205,9 @@ class Splitter:
|
||||
f"criteria has to be max_samples or impurity; got ({criteria})"
|
||||
)
|
||||
|
||||
if feature_select not in ["random", "best"]:
|
||||
if feature_select not in ["random", "best", "mutual"]:
|
||||
raise ValueError(
|
||||
"splitter must be either random or best, got "
|
||||
"splitter must be in {random, best, mutual} got "
|
||||
f"({feature_select})"
|
||||
)
|
||||
self.criterion_function = getattr(self, f"_{self._criterion}")
|
||||
@@ -381,11 +381,19 @@ class Splitter:
|
||||
dataset.shape[1], max_features
|
||||
)
|
||||
return self._select_best_set(dataset, labels, features_sets)
|
||||
# Take KBest features
|
||||
return (
|
||||
SelectKBest(k=max_features)
|
||||
.fit(dataset, labels)
|
||||
.get_support(indices=True)
|
||||
if self._feature_select == "best":
|
||||
# Select the max_features top-scoring features with SelectKBest
|
||||
return (
|
||||
SelectKBest(k=max_features)
|
||||
.fit(dataset, labels)
|
||||
.get_support(indices=True)
|
||||
)
|
||||
# Return the max_features features with the highest mutual information with the labels
|
||||
feature_list = mutual_info_classif(dataset, labels)
|
||||
return tuple(
|
||||
sorted(
|
||||
range(len(feature_list)), key=lambda sub: feature_list[sub]
|
||||
)[-max_features:]
|
||||
)
|
||||
|
||||
def get_subspace(
|
||||
|
@@ -195,10 +195,14 @@ class Splitter_test(unittest.TestCase):
|
||||
[0, 3, 7, 12], # random entropy impurity
|
||||
[1, 7, 9, 12], # random gini max_samples
|
||||
[1, 5, 8, 12], # random gini impurity
|
||||
[6, 9, 11, 12], # mutual entropy max_samples
|
||||
[6, 9, 11, 12], # mutual entropy impurity
|
||||
[6, 9, 11, 12], # mutual gini max_samples
|
||||
[6, 9, 11, 12], # mutual gini impurity
|
||||
]
|
||||
X, y = load_wine(return_X_y=True)
|
||||
rn = 0
|
||||
for feature_select in ["best", "random"]:
|
||||
for feature_select in ["best", "random", "mutual"]:
|
||||
for criterion in ["entropy", "gini"]:
|
||||
for criteria in [
|
||||
"max_samples",
|
||||
@@ -221,7 +225,7 @@ class Splitter_test(unittest.TestCase):
|
||||
# criteria,
|
||||
# )
|
||||
# )
|
||||
self.assertListEqual(expected, list(computed))
|
||||
self.assertListEqual(expected, sorted(list(computed)))
|
||||
self.assertListEqual(
|
||||
X[:, computed].tolist(), dataset.tolist()
|
||||
)
|
||||
|
Reference in New Issue
Block a user