mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-18 17:06:01 +00:00
Combinatorial explosion (#19)
* Remove itertools combinations from subspaces * Generates 5 random subspaces at most
This commit is contained in:
committed by
GitHub
parent
475ad7e752
commit
36816074ff
@@ -10,8 +10,8 @@ import os
|
||||
import numbers
|
||||
import random
|
||||
import warnings
|
||||
from math import log
|
||||
from itertools import combinations
|
||||
from math import log, factorial
|
||||
from typing import Optional
|
||||
import numpy as np
|
||||
from sklearn.base import BaseEstimator, ClassifierMixin
|
||||
from sklearn.svm import SVC, LinearSVC
|
||||
@@ -253,19 +253,32 @@ class Splitter:
|
||||
selected = feature_set
|
||||
return selected if selected is not None else feature_set
|
||||
|
||||
@staticmethod
|
||||
def _generate_spaces(features: int, max_features: int) -> list:
|
||||
comb = set()
|
||||
# Generate at most 5 combinations
|
||||
if max_features == features:
|
||||
set_length = 1
|
||||
else:
|
||||
number = factorial(features) / (
|
||||
factorial(max_features) * factorial(features - max_features)
|
||||
)
|
||||
set_length = min(5, number)
|
||||
while len(comb) < set_length:
|
||||
comb.add(
|
||||
tuple(sorted(random.sample(range(features), max_features)))
|
||||
)
|
||||
return list(comb)
|
||||
|
||||
def _get_subspaces_set(
|
||||
self, dataset: np.array, labels: np.array, max_features: int
|
||||
) -> np.array:
|
||||
features = range(dataset.shape[1])
|
||||
features_sets = list(combinations(features, max_features))
|
||||
features_sets = self._generate_spaces(dataset.shape[1], max_features)
|
||||
if len(features_sets) > 1:
|
||||
if self._splitter_type == "random":
|
||||
index = random.randint(0, len(features_sets) - 1)
|
||||
return features_sets[index]
|
||||
else:
|
||||
# get only 3 sets at most
|
||||
if len(features_sets) > 3:
|
||||
features_sets = random.sample(features_sets, 3)
|
||||
return self._select_best_set(dataset, labels, features_sets)
|
||||
else:
|
||||
return features_sets[0]
|
||||
@@ -488,7 +501,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
sample_weight: np.ndarray,
|
||||
depth: int,
|
||||
title: str,
|
||||
) -> Snode:
|
||||
) -> Optional[Snode]:
|
||||
"""Recursive function to split the original dataset into predictor
|
||||
nodes (leaves)
|
||||
|
||||
|
Reference in New Issue
Block a user