Generates 5 random subspaces at most

This commit is contained in:
2020-12-11 11:32:54 +01:00
parent 3f01234ebf
commit 6896f76ca9
2 changed files with 26 additions and 11 deletions

View File

@@ -10,7 +10,8 @@ import os
import numbers
import random
import warnings
from math import log
from math import log, factorial
from typing import Optional
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.svm import SVC, LinearSVC
@@ -255,8 +256,14 @@ class Splitter:
@staticmethod
def _generate_spaces(features: int, max_features: int) -> list:
comb = set()
# Generate at most 3 combinations
set_length = 1 if max_features == features else 3
# Generate at most 5 combinations
if max_features == features:
set_length = 1
else:
number = factorial(features) / (
factorial(max_features) * factorial(features - max_features)
)
set_length = min(5, number)
while len(comb) < set_length:
comb.add(
tuple(sorted(random.sample(range(features), max_features)))
@@ -494,7 +501,7 @@ class Stree(BaseEstimator, ClassifierMixin):
sample_weight: np.ndarray,
depth: int,
title: str,
) -> Snode:
) -> Optional[Snode]:
"""Recursive function to split the original dataset into predictor
nodes (leaves)

View File

@@ -166,6 +166,14 @@ class Splitter_test(unittest.TestCase):
self.assertEqual((6,), computed_data.shape)
self.assertListEqual(expected.tolist(), computed_data.tolist())
def test_generate_subspaces(self):
features = 250
for max_features in range(2, features):
num = len(Splitter._generate_spaces(features, max_features))
self.assertEqual(5, num)
self.assertEqual(3, len(Splitter._generate_spaces(3, 2)))
self.assertEqual(4, len(Splitter._generate_spaces(4, 3)))
def test_best_splitter_few_sets(self):
X, y = load_iris(return_X_y=True)
X = np.delete(X, 3, 1)
@@ -176,14 +184,14 @@ class Splitter_test(unittest.TestCase):
def test_splitter_parameter(self):
expected_values = [
[0, 4, 6, 12], # best entropy max_samples
[1, 4, 9, 12], # best entropy max_samples
[1, 3, 6, 10], # best entropy impurity
[0, 1, 5, 11], # best gini max_samples
[0, 1, 7, 9], # best gini impurity
[0, 4, 6, 8], # random entropy max_samples
[4, 5, 8, 9], # random entropy impurity
[0, 4, 10, 12], # random gini max_samples
[1, 5, 8, 12], # random gini impurity
[6, 8, 10, 12], # best gini max_samples
[7, 8, 10, 11], # best gini impurity
[0, 3, 8, 12], # random entropy max_samples
[0, 3, 9, 11], # random entropy impurity
[0, 4, 7, 12], # random gini max_samples
[0, 2, 5, 6], # random gini impurity
]
X, y = load_wine(return_X_y=True)
rn = 0