From 36816074ff9f5acf72fe50a9dcaf4f5c8f966f26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Sun, 10 Jan 2021 13:32:22 +0100 Subject: [PATCH] Combinatorial explosion (#19) * Remove itertools combinations from subspaces * Generates 5 random subspaces at most --- stree/Strees.py | 29 +++++++++++++++++++++-------- stree/tests/Splitter_test.py | 24 ++++++++++++++++-------- stree/tests/Stree_test.py | 2 +- 3 files changed, 38 insertions(+), 17 deletions(-) diff --git a/stree/Strees.py b/stree/Strees.py index f47775a..1eee533 100644 --- a/stree/Strees.py +++ b/stree/Strees.py @@ -10,8 +10,8 @@ import os import numbers import random import warnings -from math import log -from itertools import combinations +from math import log, factorial +from typing import Optional import numpy as np from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.svm import SVC, LinearSVC @@ -253,19 +253,32 @@ class Splitter: selected = feature_set return selected if selected is not None else feature_set + @staticmethod + def _generate_spaces(features: int, max_features: int) -> list: + comb = set() + # Generate at most 5 combinations + if max_features == features: + set_length = 1 + else: + number = factorial(features) / ( + factorial(max_features) * factorial(features - max_features) + ) + set_length = min(5, number) + while len(comb) < set_length: + comb.add( + tuple(sorted(random.sample(range(features), max_features))) + ) + return list(comb) + def _get_subspaces_set( self, dataset: np.array, labels: np.array, max_features: int ) -> np.array: - features = range(dataset.shape[1]) - features_sets = list(combinations(features, max_features)) + features_sets = self._generate_spaces(dataset.shape[1], max_features) if len(features_sets) > 1: if self._splitter_type == "random": index = random.randint(0, len(features_sets) - 1) return features_sets[index] else: - # get only 3 sets at most - if len(features_sets) > 3: - features_sets = random.sample(features_sets, 3) return self._select_best_set(dataset, labels, features_sets) else: return features_sets[0] @@ -488,7 +501,7 @@ class Stree(BaseEstimator, ClassifierMixin): sample_weight: np.ndarray, depth: int, title: str, - ) -> Snode: + ) -> Optional[Snode]: """Recursive function to split the original dataset into predictor nodes (leaves) diff --git a/stree/tests/Splitter_test.py b/stree/tests/Splitter_test.py index a0dbc96..c70039e 100644 --- a/stree/tests/Splitter_test.py +++ b/stree/tests/Splitter_test.py @@ -166,6 +166,14 @@ class Splitter_test(unittest.TestCase): self.assertEqual((6,), computed_data.shape) self.assertListEqual(expected.tolist(), computed_data.tolist()) + def test_generate_subspaces(self): + features = 250 + for max_features in range(2, features): + num = len(Splitter._generate_spaces(features, max_features)) + self.assertEqual(5, num) + self.assertEqual(3, len(Splitter._generate_spaces(3, 2))) + self.assertEqual(4, len(Splitter._generate_spaces(4, 3))) + def test_best_splitter_few_sets(self): X, y = load_iris(return_X_y=True) X = np.delete(X, 3, 1) @@ -176,14 +184,14 @@ class Splitter_test(unittest.TestCase): def test_splitter_parameter(self): expected_values = [ - [0, 1, 7, 9], # best entropy max_samples - [3, 8, 10, 11], # best entropy impurity - [0, 2, 8, 12], # best gini max_samples - [1, 2, 5, 12], # best gini impurity - [1, 2, 5, 10], # random entropy max_samples - [4, 8, 9, 12], # random entropy impurity - [3, 9, 11, 12], # random gini max_samples - [1, 5, 6, 9], # random gini impurity + [1, 4, 9, 12], # best entropy max_samples + [1, 3, 6, 10], # best entropy impurity + [6, 8, 10, 12], # best gini max_samples + [7, 8, 10, 11], # best gini impurity + [0, 3, 8, 12], # random entropy max_samples + [0, 3, 9, 11], # random entropy impurity + [0, 4, 7, 12], # random gini max_samples + [0, 2, 5, 6], # random gini impurity ] X, y = load_wine(return_X_y=True) rn = 0 diff --git a/stree/tests/Stree_test.py b/stree/tests/Stree_test.py index 77fa82a..3ebfb70 100644 --- a/stree/tests/Stree_test.py +++ b/stree/tests/Stree_test.py @@ -313,7 +313,7 @@ class Stree_test(unittest.TestCase): X, y = load_dataset(self._random_state) clf = Stree(random_state=self._random_state, max_features=2) clf.fit(X, y) - self.assertAlmostEqual(0.944, clf.score(X, y)) + self.assertAlmostEqual(0.9246666666666666, clf.score(X, y)) def test_bogus_splitter_parameter(self): clf = Stree(splitter="duck")