Combinatorial explosion (#19)

* Remove itertools combinations from subspaces

* Generates 5 random subspaces at most
This commit is contained in:
Ricardo Montañana Gómez
2021-01-10 13:32:22 +01:00
committed by GitHub
parent 475ad7e752
commit 36816074ff
3 changed files with 38 additions and 17 deletions

View File

@@ -10,8 +10,8 @@ import os
import numbers import numbers
import random import random
import warnings import warnings
from math import log from math import log, factorial
from itertools import combinations from typing import Optional
import numpy as np import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.svm import SVC, LinearSVC from sklearn.svm import SVC, LinearSVC
@@ -253,19 +253,32 @@ class Splitter:
selected = feature_set selected = feature_set
return selected if selected is not None else feature_set return selected if selected is not None else feature_set
@staticmethod
def _generate_spaces(features: int, max_features: int) -> list:
comb = set()
# Generate at most 5 combinations
if max_features == features:
set_length = 1
else:
number = factorial(features) / (
factorial(max_features) * factorial(features - max_features)
)
set_length = min(5, number)
while len(comb) < set_length:
comb.add(
tuple(sorted(random.sample(range(features), max_features)))
)
return list(comb)
def _get_subspaces_set( def _get_subspaces_set(
self, dataset: np.array, labels: np.array, max_features: int self, dataset: np.array, labels: np.array, max_features: int
) -> np.array: ) -> np.array:
features = range(dataset.shape[1]) features_sets = self._generate_spaces(dataset.shape[1], max_features)
features_sets = list(combinations(features, max_features))
if len(features_sets) > 1: if len(features_sets) > 1:
if self._splitter_type == "random": if self._splitter_type == "random":
index = random.randint(0, len(features_sets) - 1) index = random.randint(0, len(features_sets) - 1)
return features_sets[index] return features_sets[index]
else: else:
# get only 3 sets at most
if len(features_sets) > 3:
features_sets = random.sample(features_sets, 3)
return self._select_best_set(dataset, labels, features_sets) return self._select_best_set(dataset, labels, features_sets)
else: else:
return features_sets[0] return features_sets[0]
@@ -488,7 +501,7 @@ class Stree(BaseEstimator, ClassifierMixin):
sample_weight: np.ndarray, sample_weight: np.ndarray,
depth: int, depth: int,
title: str, title: str,
) -> Snode: ) -> Optional[Snode]:
"""Recursive function to split the original dataset into predictor """Recursive function to split the original dataset into predictor
nodes (leaves) nodes (leaves)

View File

@@ -166,6 +166,14 @@ class Splitter_test(unittest.TestCase):
self.assertEqual((6,), computed_data.shape) self.assertEqual((6,), computed_data.shape)
self.assertListEqual(expected.tolist(), computed_data.tolist()) self.assertListEqual(expected.tolist(), computed_data.tolist())
def test_generate_subspaces(self):
features = 250
for max_features in range(2, features):
num = len(Splitter._generate_spaces(features, max_features))
self.assertEqual(5, num)
self.assertEqual(3, len(Splitter._generate_spaces(3, 2)))
self.assertEqual(4, len(Splitter._generate_spaces(4, 3)))
def test_best_splitter_few_sets(self): def test_best_splitter_few_sets(self):
X, y = load_iris(return_X_y=True) X, y = load_iris(return_X_y=True)
X = np.delete(X, 3, 1) X = np.delete(X, 3, 1)
@@ -176,14 +184,14 @@ class Splitter_test(unittest.TestCase):
def test_splitter_parameter(self): def test_splitter_parameter(self):
expected_values = [ expected_values = [
[0, 1, 7, 9], # best entropy max_samples [1, 4, 9, 12], # best entropy max_samples
[3, 8, 10, 11], # best entropy impurity [1, 3, 6, 10], # best entropy impurity
[0, 2, 8, 12], # best gini max_samples [6, 8, 10, 12], # best gini max_samples
[1, 2, 5, 12], # best gini impurity [7, 8, 10, 11], # best gini impurity
[1, 2, 5, 10], # random entropy max_samples [0, 3, 8, 12], # random entropy max_samples
[4, 8, 9, 12], # random entropy impurity [0, 3, 9, 11], # random entropy impurity
[3, 9, 11, 12], # random gini max_samples [0, 4, 7, 12], # random gini max_samples
[1, 5, 6, 9], # random gini impurity [0, 2, 5, 6], # random gini impurity
] ]
X, y = load_wine(return_X_y=True) X, y = load_wine(return_X_y=True)
rn = 0 rn = 0

View File

@@ -313,7 +313,7 @@ class Stree_test(unittest.TestCase):
X, y = load_dataset(self._random_state) X, y = load_dataset(self._random_state)
clf = Stree(random_state=self._random_state, max_features=2) clf = Stree(random_state=self._random_state, max_features=2)
clf.fit(X, y) clf.fit(X, y)
self.assertAlmostEqual(0.944, clf.score(X, y)) self.assertAlmostEqual(0.9246666666666666, clf.score(X, y))
def test_bogus_splitter_parameter(self): def test_bogus_splitter_parameter(self):
clf = Stree(splitter="duck") clf = Stree(splitter="duck")