Generates 5 random subspaces at most

This commit is contained in:
2020-12-11 11:32:54 +01:00
parent 3f01234ebf
commit 6896f76ca9
2 changed files with 26 additions and 11 deletions

View File

@@ -10,7 +10,8 @@ import os
import numbers import numbers
import random import random
import warnings import warnings
from math import log from math import log, factorial
from typing import Optional
import numpy as np import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.svm import SVC, LinearSVC from sklearn.svm import SVC, LinearSVC
@@ -255,8 +256,14 @@ class Splitter:
@staticmethod @staticmethod
def _generate_spaces(features: int, max_features: int) -> list: def _generate_spaces(features: int, max_features: int) -> list:
comb = set() comb = set()
# Generate at most 3 combinations # Generate at most 5 combinations
set_length = 1 if max_features == features else 3 if max_features == features:
set_length = 1
else:
number = factorial(features) / (
factorial(max_features) * factorial(features - max_features)
)
set_length = min(5, number)
while len(comb) < set_length: while len(comb) < set_length:
comb.add( comb.add(
tuple(sorted(random.sample(range(features), max_features))) tuple(sorted(random.sample(range(features), max_features)))
@@ -494,7 +501,7 @@ class Stree(BaseEstimator, ClassifierMixin):
sample_weight: np.ndarray, sample_weight: np.ndarray,
depth: int, depth: int,
title: str, title: str,
) -> Snode: ) -> Optional[Snode]:
"""Recursive function to split the original dataset into predictor """Recursive function to split the original dataset into predictor
nodes (leaves) nodes (leaves)

View File

@@ -166,6 +166,14 @@ class Splitter_test(unittest.TestCase):
self.assertEqual((6,), computed_data.shape) self.assertEqual((6,), computed_data.shape)
self.assertListEqual(expected.tolist(), computed_data.tolist()) self.assertListEqual(expected.tolist(), computed_data.tolist())
def test_generate_subspaces(self):
features = 250
for max_features in range(2, features):
num = len(Splitter._generate_spaces(features, max_features))
self.assertEqual(5, num)
self.assertEqual(3, len(Splitter._generate_spaces(3, 2)))
self.assertEqual(4, len(Splitter._generate_spaces(4, 3)))
def test_best_splitter_few_sets(self): def test_best_splitter_few_sets(self):
X, y = load_iris(return_X_y=True) X, y = load_iris(return_X_y=True)
X = np.delete(X, 3, 1) X = np.delete(X, 3, 1)
@@ -176,14 +184,14 @@ class Splitter_test(unittest.TestCase):
def test_splitter_parameter(self): def test_splitter_parameter(self):
expected_values = [ expected_values = [
[0, 4, 6, 12], # best entropy max_samples [1, 4, 9, 12], # best entropy max_samples
[1, 3, 6, 10], # best entropy impurity [1, 3, 6, 10], # best entropy impurity
[0, 1, 5, 11], # best gini max_samples [6, 8, 10, 12], # best gini max_samples
[0, 1, 7, 9], # best gini impurity [7, 8, 10, 11], # best gini impurity
[0, 4, 6, 8], # random entropy max_samples [0, 3, 8, 12], # random entropy max_samples
[4, 5, 8, 9], # random entropy impurity [0, 3, 9, 11], # random entropy impurity
[0, 4, 10, 12], # random gini max_samples [0, 4, 7, 12], # random gini max_samples
[1, 5, 8, 12], # random gini impurity [0, 2, 5, 6], # random gini impurity
] ]
X, y = load_wine(return_X_y=True) X, y = load_wine(return_X_y=True)
rn = 0 rn = 0