Remove itertools combinations from subspaces

This commit is contained in:
2020-12-10 14:14:42 +01:00
parent 475ad7e752
commit 3f01234ebf
3 changed files with 21 additions and 15 deletions

View File

@@ -11,7 +11,6 @@ import numbers
import random import random
import warnings import warnings
from math import log from math import log
from itertools import combinations
import numpy as np import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.svm import SVC, LinearSVC from sklearn.svm import SVC, LinearSVC
@@ -253,19 +252,26 @@ class Splitter:
selected = feature_set selected = feature_set
return selected if selected is not None else feature_set return selected if selected is not None else feature_set
@staticmethod
def _generate_spaces(features: int, max_features: int) -> list:
comb = set()
# Generate at most 3 combinations
set_length = 1 if max_features == features else 3
while len(comb) < set_length:
comb.add(
tuple(sorted(random.sample(range(features), max_features)))
)
return list(comb)
def _get_subspaces_set( def _get_subspaces_set(
self, dataset: np.array, labels: np.array, max_features: int self, dataset: np.array, labels: np.array, max_features: int
) -> np.array: ) -> np.array:
features = range(dataset.shape[1]) features_sets = self._generate_spaces(dataset.shape[1], max_features)
features_sets = list(combinations(features, max_features))
if len(features_sets) > 1: if len(features_sets) > 1:
if self._splitter_type == "random": if self._splitter_type == "random":
index = random.randint(0, len(features_sets) - 1) index = random.randint(0, len(features_sets) - 1)
return features_sets[index] return features_sets[index]
else: else:
# get only 3 sets at most
if len(features_sets) > 3:
features_sets = random.sample(features_sets, 3)
return self._select_best_set(dataset, labels, features_sets) return self._select_best_set(dataset, labels, features_sets)
else: else:
return features_sets[0] return features_sets[0]

View File

@@ -176,14 +176,14 @@ class Splitter_test(unittest.TestCase):
def test_splitter_parameter(self): def test_splitter_parameter(self):
expected_values = [ expected_values = [
[0, 1, 7, 9], # best entropy max_samples [0, 4, 6, 12], # best entropy max_samples
[3, 8, 10, 11], # best entropy impurity [1, 3, 6, 10], # best entropy impurity
[0, 2, 8, 12], # best gini max_samples [0, 1, 5, 11], # best gini max_samples
[1, 2, 5, 12], # best gini impurity [0, 1, 7, 9], # best gini impurity
[1, 2, 5, 10], # random entropy max_samples [0, 4, 6, 8], # random entropy max_samples
[4, 8, 9, 12], # random entropy impurity [4, 5, 8, 9], # random entropy impurity
[3, 9, 11, 12], # random gini max_samples [0, 4, 10, 12], # random gini max_samples
[1, 5, 6, 9], # random gini impurity [1, 5, 8, 12], # random gini impurity
] ]
X, y = load_wine(return_X_y=True) X, y = load_wine(return_X_y=True)
rn = 0 rn = 0

View File

@@ -313,7 +313,7 @@ class Stree_test(unittest.TestCase):
X, y = load_dataset(self._random_state) X, y = load_dataset(self._random_state)
clf = Stree(random_state=self._random_state, max_features=2) clf = Stree(random_state=self._random_state, max_features=2)
clf.fit(X, y) clf.fit(X, y)
self.assertAlmostEqual(0.944, clf.score(X, y)) self.assertAlmostEqual(0.9246666666666666, clf.score(X, y))
def test_bogus_splitter_parameter(self): def test_bogus_splitter_parameter(self):
clf = Stree(splitter="duck") clf = Stree(splitter="duck")