Mirror of https://github.com/Doctorado-ML/STree.git, synced 2025-08-17 08:26:00 +00:00

Implement splitter type mutual info
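This commit adds a third feature_select strategy, "mutual", to the Splitter class: features are scored with scikit-learn's mutual_info_classif and the max_features highest-scoring ones are kept. A minimal usage sketch, assuming the new option is reachable through Stree's splitter hyperparameter (as the updated error message suggests) and that the stree package from this repository is installed:

    # Hedged usage sketch, not part of the commit: train an Stree with the new
    # "mutual" feature selection. Assumes Stree forwards splitter="mutual" to
    # Splitter's feature_select parameter.
    from sklearn.datasets import load_wine
    from stree import Stree

    X, y = load_wine(return_X_y=True)
    clf = Stree(splitter="mutual", max_features=4, random_state=0)
    clf.fit(X, y)
    print(clf.score(X, y))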
@@ -11,7 +11,7 @@ from typing import Optional
 import numpy as np
 from sklearn.base import BaseEstimator, ClassifierMixin
 from sklearn.svm import SVC, LinearSVC
-from sklearn.feature_selection import SelectKBest
+from sklearn.feature_selection import SelectKBest, mutual_info_classif
 from sklearn.preprocessing import StandardScaler
 from sklearn.utils.multiclass import check_classification_targets
 from sklearn.exceptions import ConvergenceWarning
@@ -205,9 +205,9 @@ class Splitter:
                 f"criteria has to be max_samples or impurity; got ({criteria})"
             )
 
-        if feature_select not in ["random", "best"]:
+        if feature_select not in ["random", "best", "mutual"]:
             raise ValueError(
-                "splitter must be either random or best, got "
+                "splitter must be in {random, best, mutual} got "
                 f"({feature_select})"
             )
         self.criterion_function = getattr(self, f"_{self._criterion}")
@@ -381,11 +381,19 @@ class Splitter:
                 dataset.shape[1], max_features
             )
             return self._select_best_set(dataset, labels, features_sets)
-        # Take KBest features
-        return (
-            SelectKBest(k=max_features)
-            .fit(dataset, labels)
-            .get_support(indices=True)
+        if self._feature_select == "best":
+            # Take KBest features
+            return (
+                SelectKBest(k=max_features)
+                .fit(dataset, labels)
+                .get_support(indices=True)
+            )
+        # return best features with mutual info with the label
+        feature_list = mutual_info_classif(dataset, labels)
+        return tuple(
+            sorted(
+                range(len(feature_list)), key=lambda sub: feature_list[sub]
+            )[-max_features:]
         )
 
     def get_subspace(
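The new fallback branch above is the core of the commit: every feature is scored with mutual_info_classif and the indices of the max_features highest scores are returned as a tuple. A standalone sketch of that selection step, shown outside the Splitter class for illustration (the wine data is the same dataset the tests use):

    # Illustrative restatement of the "mutual" selection step on the wine data.
    import numpy as np
    from sklearn.datasets import load_wine
    from sklearn.feature_selection import mutual_info_classif

    X, y = load_wine(return_X_y=True)
    max_features = 4

    # Mutual information of each feature with the class label.
    feature_list = mutual_info_classif(X, y)

    # Indices of the max_features highest-scoring features,
    # ordered by ascending score (the same idiom as the diff).
    selected = tuple(
        sorted(
            range(len(feature_list)), key=lambda sub: feature_list[sub]
        )[-max_features:]
    )

    # np.argsort picks the same set of indices.
    assert set(selected) == set(np.argsort(feature_list)[-max_features:])
    print(selected)

np.argsort(feature_list)[-max_features:] would be an equivalent, slightly terser way to obtain the same index set.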
@@ -195,10 +195,14 @@ class Splitter_test(unittest.TestCase):
             [0, 3, 7, 12],  # random entropy impurity
             [1, 7, 9, 12],  # random gini max_samples
             [1, 5, 8, 12],  # random gini impurity
+            [6, 9, 11, 12],  # mutual entropy max_samples
+            [6, 9, 11, 12],  # mutual entropy impurity
+            [6, 9, 11, 12],  # mutual gini max_samples
+            [6, 9, 11, 12],  # mutual gini impurity
         ]
         X, y = load_wine(return_X_y=True)
         rn = 0
-        for feature_select in ["best", "random"]:
+        for feature_select in ["best", "random", "mutual"]:
             for criterion in ["entropy", "gini"]:
                 for criteria in [
                     "max_samples",
@@ -221,7 +225,7 @@ class Splitter_test(unittest.TestCase):
                     # criteria,
                     # )
                     # )
-                    self.assertListEqual(expected, list(computed))
+                    self.assertListEqual(expected, sorted(list(computed)))
                     self.assertListEqual(
                         X[:, computed].tolist(), dataset.tolist()
                     )
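The assertion change is needed because the mutual-info branch returns indices ordered by ascending score rather than by feature position, while the expected lists are written in index order; the best and random cases already matched in index order, so sorting them again is harmless. A tiny illustration with hypothetical scores (not taken from the wine data):

    # Hypothetical scores showing why the test sorts the computed indices:
    # the top-k indices come back ordered by score, not by feature index.
    feature_list = [0.10, 0.95, 0.30, 0.80]
    max_features = 2

    computed = sorted(
        range(len(feature_list)), key=lambda sub: feature_list[sub]
    )[-max_features:]

    print(computed)          # [3, 1]  ordered by ascending score
    print(sorted(computed))  # [1, 3]  ordered by index, like the expected lists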