mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-18 17:06:01 +00:00
committed by
GitHub
parent
a4aac9d310
commit
02de394c96
@@ -15,6 +15,7 @@ from typing import Optional
|
||||
import numpy as np
|
||||
from sklearn.base import BaseEstimator, ClassifierMixin
|
||||
from sklearn.svm import SVC, LinearSVC
|
||||
from sklearn.feature_selection import SelectKBest
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.utils import check_consistent_length
|
||||
from sklearn.utils.multiclass import check_classification_targets
|
||||
@@ -179,7 +180,7 @@ class Splitter:
|
||||
self,
|
||||
clf: SVC = None,
|
||||
criterion: str = None,
|
||||
splitter_type: str = None,
|
||||
feature_select: str = None,
|
||||
criteria: str = None,
|
||||
min_samples_split: int = None,
|
||||
random_state=None,
|
||||
@@ -192,7 +193,7 @@ class Splitter:
|
||||
self._criterion = criterion
|
||||
self._min_samples_split = min_samples_split
|
||||
self._criteria = criteria
|
||||
self._splitter_type = splitter_type
|
||||
self._feature_select = feature_select
|
||||
self._normalize = normalize
|
||||
|
||||
if clf is None:
|
||||
@@ -211,9 +212,10 @@ class Splitter:
|
||||
f"criteria has to be max_samples or impurity; got ({criteria})"
|
||||
)
|
||||
|
||||
if splitter_type not in ["random", "best"]:
|
||||
if feature_select not in ["random", "best"]:
|
||||
raise ValueError(
|
||||
f"splitter must be either random or best, got({splitter_type})"
|
||||
"splitter must be either random or best, got "
|
||||
f"({feature_select})"
|
||||
)
|
||||
self.criterion_function = getattr(self, f"_{self._criterion}")
|
||||
self.decision_criteria = getattr(self, f"_{self._criteria}")
|
||||
@@ -330,13 +332,10 @@ class Splitter:
|
||||
"""
|
||||
comb = set()
|
||||
# Generate at most 5 combinations
|
||||
if max_features == features:
|
||||
set_length = 1
|
||||
else:
|
||||
number = factorial(features) / (
|
||||
factorial(max_features) * factorial(features - max_features)
|
||||
)
|
||||
set_length = min(5, number)
|
||||
number = factorial(features) / (
|
||||
factorial(max_features) * factorial(features - max_features)
|
||||
)
|
||||
set_length = min(5, number)
|
||||
while len(comb) < set_length:
|
||||
comb.add(
|
||||
tuple(sorted(random.sample(range(features), max_features)))
|
||||
@@ -345,9 +344,9 @@ class Splitter:
|
||||
|
||||
def _get_subspaces_set(
|
||||
self, dataset: np.array, labels: np.array, max_features: int
|
||||
) -> np.array:
|
||||
) -> tuple:
|
||||
"""Compute the indices of the features selected by splitter depending
|
||||
on the self._splitter_type hyper parameter
|
||||
on the self._feature_select hyper parameter
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -361,23 +360,28 @@ class Splitter:
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.array
|
||||
tuple
|
||||
indices of the features selected
|
||||
"""
|
||||
features_sets = self._generate_spaces(dataset.shape[1], max_features)
|
||||
if len(features_sets) > 1:
|
||||
if self._splitter_type == "random":
|
||||
index = random.randint(0, len(features_sets) - 1)
|
||||
return features_sets[index]
|
||||
else:
|
||||
return self._select_best_set(dataset, labels, features_sets)
|
||||
else:
|
||||
return features_sets[0]
|
||||
if dataset.shape[1] == max_features:
|
||||
# No feature reduction applies
|
||||
return tuple(range(dataset.shape[1]))
|
||||
if self._feature_select == "random":
|
||||
features_sets = self._generate_spaces(
|
||||
dataset.shape[1], max_features
|
||||
)
|
||||
return self._select_best_set(dataset, labels, features_sets)
|
||||
# Take KBest features
|
||||
return (
|
||||
SelectKBest(k=max_features)
|
||||
.fit(dataset, labels)
|
||||
.get_support(indices=True)
|
||||
)
|
||||
|
||||
def get_subspace(
|
||||
self, dataset: np.array, labels: np.array, max_features: int
|
||||
) -> tuple:
|
||||
"""Return a subspace of the selected dataset of max_features length.
|
||||
"""Re3turn a subspace of the selected dataset of max_features length.
|
||||
Depending on hyperparmeter
|
||||
|
||||
Parameters
|
||||
@@ -613,7 +617,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
self.splitter_ = Splitter(
|
||||
clf=self._build_clf(),
|
||||
criterion=self.criterion,
|
||||
splitter_type=self.splitter,
|
||||
feature_select=self.splitter,
|
||||
criteria=self.split_criteria,
|
||||
random_state=self.random_state,
|
||||
min_samples_split=self.min_samples_split,
|
||||
|
Reference in New Issue
Block a user