mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-15 15:36:00 +00:00
Generates 5 random subspaces at most
This commit is contained in:
@@ -10,7 +10,8 @@ import os
|
||||
import numbers
|
||||
import random
|
||||
import warnings
|
||||
from math import log
|
||||
from math import log, factorial
|
||||
from typing import Optional
|
||||
import numpy as np
|
||||
from sklearn.base import BaseEstimator, ClassifierMixin
|
||||
from sklearn.svm import SVC, LinearSVC
|
||||
@@ -255,8 +256,14 @@ class Splitter:
|
||||
@staticmethod
|
||||
def _generate_spaces(features: int, max_features: int) -> list:
|
||||
comb = set()
|
||||
# Generate at most 3 combinations
|
||||
set_length = 1 if max_features == features else 3
|
||||
# Generate at most 5 combinations
|
||||
if max_features == features:
|
||||
set_length = 1
|
||||
else:
|
||||
number = factorial(features) / (
|
||||
factorial(max_features) * factorial(features - max_features)
|
||||
)
|
||||
set_length = min(5, number)
|
||||
while len(comb) < set_length:
|
||||
comb.add(
|
||||
tuple(sorted(random.sample(range(features), max_features)))
|
||||
@@ -494,7 +501,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
sample_weight: np.ndarray,
|
||||
depth: int,
|
||||
title: str,
|
||||
) -> Snode:
|
||||
) -> Optional[Snode]:
|
||||
"""Recursive function to split the original dataset into predictor
|
||||
nodes (leaves)
|
||||
|
||||
|
@@ -166,6 +166,14 @@ class Splitter_test(unittest.TestCase):
|
||||
self.assertEqual((6,), computed_data.shape)
|
||||
self.assertListEqual(expected.tolist(), computed_data.tolist())
|
||||
|
||||
def test_generate_subspaces(self):
    """Check how many feature subspaces Splitter._generate_spaces returns.

    With 250 features, every subset size from 2 up to 249 must yield
    exactly 5 subspaces. With only a handful of features the result is
    limited by the number of distinct combinations available.
    """
    total_features = 250
    subset_size = 2
    while subset_size < total_features:
        subspaces = Splitter._generate_spaces(total_features, subset_size)
        self.assertEqual(5, len(subspaces))
        subset_size += 1
    # Small cases: only 3 (3 choose 2) and 4 (4 choose 3) distinct
    # combinations exist, so fewer than 5 subspaces come back.
    for n_features, n_selected, expected in ((3, 2, 3), (4, 3, 4)):
        self.assertEqual(
            expected, len(Splitter._generate_spaces(n_features, n_selected))
        )
|
||||
|
||||
def test_best_splitter_few_sets(self):
|
||||
X, y = load_iris(return_X_y=True)
|
||||
X = np.delete(X, 3, 1)
|
||||
@@ -176,14 +184,14 @@ class Splitter_test(unittest.TestCase):
|
||||
|
||||
def test_splitter_parameter(self):
|
||||
expected_values = [
|
||||
[0, 4, 6, 12], # best entropy max_samples
|
||||
[1, 4, 9, 12], # best entropy max_samples
|
||||
[1, 3, 6, 10], # best entropy impurity
|
||||
[0, 1, 5, 11], # best gini max_samples
|
||||
[0, 1, 7, 9], # best gini impurity
|
||||
[0, 4, 6, 8], # random entropy max_samples
|
||||
[4, 5, 8, 9], # random entropy impurity
|
||||
[0, 4, 10, 12], # random gini max_samples
|
||||
[1, 5, 8, 12], # random gini impurity
|
||||
[6, 8, 10, 12], # best gini max_samples
|
||||
[7, 8, 10, 11], # best gini impurity
|
||||
[0, 3, 8, 12], # random entropy max_samples
|
||||
[0, 3, 9, 11], # random entropy impurity
|
||||
[0, 4, 7, 12], # random gini max_samples
|
||||
[0, 2, 5, 6], # random gini impurity
|
||||
]
|
||||
X, y = load_wine(return_X_y=True)
|
||||
rn = 0
|
||||
|
Reference in New Issue
Block a user