mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-16 16:06:01 +00:00
Remove itertools combinations from subspaces
This commit is contained in:
@@ -11,7 +11,6 @@ import numbers
|
|||||||
import random
|
import random
|
||||||
import warnings
|
import warnings
|
||||||
from math import log
|
from math import log
|
||||||
from itertools import combinations
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sklearn.base import BaseEstimator, ClassifierMixin
|
from sklearn.base import BaseEstimator, ClassifierMixin
|
||||||
from sklearn.svm import SVC, LinearSVC
|
from sklearn.svm import SVC, LinearSVC
|
||||||
@@ -253,19 +252,26 @@ class Splitter:
|
|||||||
selected = feature_set
|
selected = feature_set
|
||||||
return selected if selected is not None else feature_set
|
return selected if selected is not None else feature_set
|
||||||
|
|
||||||
|
@staticmethod
def _generate_spaces(features: int, max_features: int) -> list:
    """Draw up to three distinct random feature subsets.

    Each subset is a sorted tuple of ``max_features`` distinct indices
    sampled from ``range(features)``.

    Parameters
    ----------
    features : int
        total number of features available
    max_features : int
        number of features in each subset

    Returns
    -------
    list
        distinct sorted index tuples, in no particular order: three of
        them when at least three different subsets exist, otherwise every
        possible subset

    Raises
    ------
    ValueError
        raised by ``random.sample`` when ``max_features`` is negative or
        larger than ``features``
    """
    # Generate at most 3 combinations
    limit = 3
    # Count the distinct subsets, C(features, max_features), stopping as
    # soon as the cap is reached. Asking the loop below for more subsets
    # than exist (e.g. features=2, max_features=1) would never terminate.
    available = 1
    for step in range(1, min(max_features, features - max_features) + 1):
        # Multiplicative binomial recurrence; the division is always exact
        available = available * (features - step + 1) // step
        if available >= limit:
            break
    set_length = min(limit, available)
    comb = set()
    while len(comb) < set_length:
        comb.add(
            tuple(sorted(random.sample(range(features), max_features)))
        )
    return list(comb)
|
||||||
|
|
||||||
def _get_subspaces_set(
|
def _get_subspaces_set(
|
||||||
self, dataset: np.array, labels: np.array, max_features: int
|
self, dataset: np.array, labels: np.array, max_features: int
|
||||||
) -> np.array:
|
) -> np.array:
|
||||||
features = range(dataset.shape[1])
|
features_sets = self._generate_spaces(dataset.shape[1], max_features)
|
||||||
features_sets = list(combinations(features, max_features))
|
|
||||||
if len(features_sets) > 1:
|
if len(features_sets) > 1:
|
||||||
if self._splitter_type == "random":
|
if self._splitter_type == "random":
|
||||||
index = random.randint(0, len(features_sets) - 1)
|
index = random.randint(0, len(features_sets) - 1)
|
||||||
return features_sets[index]
|
return features_sets[index]
|
||||||
else:
|
else:
|
||||||
# get only 3 sets at most
|
|
||||||
if len(features_sets) > 3:
|
|
||||||
features_sets = random.sample(features_sets, 3)
|
|
||||||
return self._select_best_set(dataset, labels, features_sets)
|
return self._select_best_set(dataset, labels, features_sets)
|
||||||
else:
|
else:
|
||||||
return features_sets[0]
|
return features_sets[0]
|
||||||
|
@@ -176,14 +176,14 @@ class Splitter_test(unittest.TestCase):
|
|||||||
|
|
||||||
def test_splitter_parameter(self):
|
def test_splitter_parameter(self):
|
||||||
expected_values = [
|
expected_values = [
|
||||||
[0, 1, 7, 9], # best entropy max_samples
|
[0, 4, 6, 12], # best entropy max_samples
|
||||||
[3, 8, 10, 11], # best entropy impurity
|
[1, 3, 6, 10], # best entropy impurity
|
||||||
[0, 2, 8, 12], # best gini max_samples
|
[0, 1, 5, 11], # best gini max_samples
|
||||||
[1, 2, 5, 12], # best gini impurity
|
[0, 1, 7, 9], # best gini impurity
|
||||||
[1, 2, 5, 10], # random entropy max_samples
|
[0, 4, 6, 8], # random entropy max_samples
|
||||||
[4, 8, 9, 12], # random entropy impurity
|
[4, 5, 8, 9], # random entropy impurity
|
||||||
[3, 9, 11, 12], # random gini max_samples
|
[0, 4, 10, 12], # random gini max_samples
|
||||||
[1, 5, 6, 9], # random gini impurity
|
[1, 5, 8, 12], # random gini impurity
|
||||||
]
|
]
|
||||||
X, y = load_wine(return_X_y=True)
|
X, y = load_wine(return_X_y=True)
|
||||||
rn = 0
|
rn = 0
|
||||||
|
@@ -313,7 +313,7 @@ class Stree_test(unittest.TestCase):
|
|||||||
X, y = load_dataset(self._random_state)
|
X, y = load_dataset(self._random_state)
|
||||||
clf = Stree(random_state=self._random_state, max_features=2)
|
clf = Stree(random_state=self._random_state, max_features=2)
|
||||||
clf.fit(X, y)
|
clf.fit(X, y)
|
||||||
self.assertAlmostEqual(0.944, clf.score(X, y))
|
self.assertAlmostEqual(0.9246666666666666, clf.score(X, y))
|
||||||
|
|
||||||
def test_bogus_splitter_parameter(self):
|
def test_bogus_splitter_parameter(self):
|
||||||
clf = Stree(splitter="duck")
|
clf = Stree(splitter="duck")
|
||||||
|
Reference in New Issue
Block a user