From 6896f76ca9925a3b20935f85e6a0a035534d7204 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?=
Date: Fri, 11 Dec 2020 11:32:54 +0100
Subject: [PATCH] Generates 5 random subspaces at most

---
Review note: the number of possible subspaces C(features, max_features) is
computed with integer floor division (//). The quotient of these factorials is
always an exact integer, so set_length stays an int and the computation cannot
raise OverflowError when the binomial coefficient exceeds the float range.

 stree/Strees.py              | 15 +++++++++++----
 stree/tests/Splitter_test.py | 22 +++++++++++++++-------
 2 files changed, 26 insertions(+), 11 deletions(-)

diff --git a/stree/Strees.py b/stree/Strees.py
index e04bee3..1eee533 100644
--- a/stree/Strees.py
+++ b/stree/Strees.py
@@ -10,7 +10,8 @@ import os
 import numbers
 import random
 import warnings
-from math import log
+from math import log, factorial
+from typing import Optional
 import numpy as np
 from sklearn.base import BaseEstimator, ClassifierMixin
 from sklearn.svm import SVC, LinearSVC
@@ -255,8 +256,14 @@ class Splitter:
     @staticmethod
     def _generate_spaces(features: int, max_features: int) -> list:
         comb = set()
-        # Generate at most 3 combinations
-        set_length = 1 if max_features == features else 3
+        # Generate min(5, C(features, max_features)) distinct combinations
+        if max_features == features:
+            set_length = 1
+        else:
+            number = factorial(features) // (
+                factorial(max_features) * factorial(features - max_features)
+            )
+            set_length = min(5, number)
         while len(comb) < set_length:
             comb.add(
                 tuple(sorted(random.sample(range(features), max_features)))
@@ -494,7 +501,7 @@ class Stree(BaseEstimator, ClassifierMixin):
         sample_weight: np.ndarray,
         depth: int,
         title: str,
-    ) -> Snode:
+    ) -> Optional[Snode]:
         """Recursive function to split the original dataset into predictor
         nodes (leaves)
 
diff --git a/stree/tests/Splitter_test.py b/stree/tests/Splitter_test.py
index 6a5e4f8..c70039e 100644
--- a/stree/tests/Splitter_test.py
+++ b/stree/tests/Splitter_test.py
@@ -166,6 +166,14 @@ class Splitter_test(unittest.TestCase):
         self.assertEqual((6,), computed_data.shape)
         self.assertListEqual(expected.tolist(), computed_data.tolist())
 
+    def test_generate_subspaces(self):
+        features = 250
+        for max_features in range(2, features):
+            num = len(Splitter._generate_spaces(features, max_features))
+            self.assertEqual(5, num)
+        self.assertEqual(3, len(Splitter._generate_spaces(3, 2)))
+        self.assertEqual(4, len(Splitter._generate_spaces(4, 3)))
+
     def test_best_splitter_few_sets(self):
         X, y = load_iris(return_X_y=True)
         X = np.delete(X, 3, 1)
@@ -176,14 +184,14 @@ class Splitter_test(unittest.TestCase):
 
     def test_splitter_parameter(self):
         expected_values = [
-            [0, 4, 6, 12],  # best entropy max_samples
+            [1, 4, 9, 12],  # best entropy max_samples
             [1, 3, 6, 10],  # best entropy impurity
-            [0, 1, 5, 11],  # best gini max_samples
-            [0, 1, 7, 9],  # best gini impurity
-            [0, 4, 6, 8],  # random entropy max_samples
-            [4, 5, 8, 9],  # random entropy impurity
-            [0, 4, 10, 12],  # random gini max_samples
-            [1, 5, 8, 12],  # random gini impurity
+            [6, 8, 10, 12],  # best gini max_samples
+            [7, 8, 10, 11],  # best gini impurity
+            [0, 3, 8, 12],  # random entropy max_samples
+            [0, 3, 9, 11],  # random entropy impurity
+            [0, 4, 7, 12],  # random gini max_samples
+            [0, 2, 5, 6],  # random gini impurity
         ]
         X, y = load_wine(return_X_y=True)
         rn = 0