From 5cef0f48753dde283d717db6c01c4edd9343fa2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Sat, 1 May 2021 23:38:34 +0200 Subject: [PATCH] Implement splitter type mutual info --- stree/Strees.py | 24 ++++++++++++++++-------- stree/tests/Splitter_test.py | 8 ++++++-- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/stree/Strees.py b/stree/Strees.py index 3062364..15c794e 100644 --- a/stree/Strees.py +++ b/stree/Strees.py @@ -11,7 +11,7 @@ from typing import Optional import numpy as np from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.svm import SVC, LinearSVC -from sklearn.feature_selection import SelectKBest +from sklearn.feature_selection import SelectKBest, mutual_info_classif from sklearn.preprocessing import StandardScaler from sklearn.utils.multiclass import check_classification_targets from sklearn.exceptions import ConvergenceWarning @@ -205,9 +205,9 @@ class Splitter: f"criteria has to be max_samples or impurity; got ({criteria})" ) - if feature_select not in ["random", "best"]: + if feature_select not in ["random", "best", "mutual"]: raise ValueError( - "splitter must be either random or best, got " + "splitter must be in {random, best, mutual} got " f"({feature_select})" ) self.criterion_function = getattr(self, f"_{self._criterion}") @@ -381,11 +381,19 @@ class Splitter: dataset.shape[1], max_features ) return self._select_best_set(dataset, labels, features_sets) - # Take KBest features - return ( - SelectKBest(k=max_features) - .fit(dataset, labels) - .get_support(indices=True) + if self._feature_select == "best": + # Take KBest features + return ( + SelectKBest(k=max_features) + .fit(dataset, labels) + .get_support(indices=True) + ) + # return best features with mutual info with the label + feature_list = mutual_info_classif(dataset, labels) + return tuple( + sorted( + range(len(feature_list)), key=lambda sub: feature_list[sub] + )[-max_features:] ) def get_subspace( diff --git a/stree/tests/Splitter_test.py b/stree/tests/Splitter_test.py index 3e45f29..cee4df5 100644 --- a/stree/tests/Splitter_test.py +++ b/stree/tests/Splitter_test.py @@ -195,10 +195,14 @@ class Splitter_test(unittest.TestCase): [0, 3, 7, 12], # random entropy impurity [1, 7, 9, 12], # random gini max_samples [1, 5, 8, 12], # random gini impurity + [6, 9, 11, 12], # mutual entropy max_samples + [6, 9, 11, 12], # mutual entropy impurity + [6, 9, 11, 12], # mutual gini max_samples + [6, 9, 11, 12], # mutual gini impurity ] X, y = load_wine(return_X_y=True) rn = 0 - for feature_select in ["best", "random"]: + for feature_select in ["best", "random", "mutual"]: for criterion in ["entropy", "gini"]: for criteria in [ "max_samples", @@ -221,7 +225,7 @@ class Splitter_test(unittest.TestCase): # criteria, # ) # ) - self.assertListEqual(expected, list(computed)) + self.assertListEqual(expected, sorted(list(computed))) self.assertListEqual( X[:, computed].tolist(), dataset.tolist() )