Solve combinatorial explosion

This commit is contained in:
Ricardo Montañana Gómez 2020-12-11 11:29:45 +01:00
parent 7a49d672df
commit ca773d3537
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
3 changed files with 51 additions and 23 deletions

View File

@ -1,19 +1,19 @@
repos:
- repo: https://github.com/ambv/black
- repo: https://github.com/ambv/black
rev: stable
hooks:
- id: black
language_version: python3.8
- repo: https://gitlab.com/pycqa/flake8
- id: black
language_version: python3.8
- repo: https://gitlab.com/pycqa/flake8
rev: 3.8.3
hooks:
- id: flake8
- repo: https://github.com/pre-commit/mirrors-mypy
rev: '' # Use the sha / tag you want to point at
hooks:
- id: mypy
args: [--strict, --ignore-missing-imports]
- repo: local
- id: flake8
# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: '' # Use the sha / tag you want to point at
# hooks:
# - id: mypy
# args: [--strict, --ignore-missing-imports]
- repo: local
hooks:
- id: unittest
name: unittest

View File

@ -8,8 +8,8 @@ Build a forest of oblique trees based on STree
from __future__ import annotations
import random
import sys
from math import factorial
from typing import Union, Optional, Tuple, List
from itertools import combinations
import numpy as np
from sklearn.utils.multiclass import check_classification_targets
from sklearn.base import clone, BaseEstimator, ClassifierMixin
@ -189,12 +189,28 @@ class Odte(BaseEnsemble, ClassifierMixin): # type: ignore
)
return max_features
@staticmethod
def _generate_spaces(features: int, max_features: int) -> list:
    """Build a small random collection of distinct feature subspaces.

    Each subspace is a sorted tuple of ``max_features`` distinct indices
    sampled from ``range(features)``. Only a handful of subspaces are
    produced instead of enumerating every combination, which would
    explode combinatorially for wide datasets.

    Parameters
    ----------
    features : int
        total number of features available
    max_features : int
        number of features in every subspace

    Returns
    -------
    list
        ``min(5, C(features, max_features))`` distinct sorted tuples, in
        arbitrary order (they are collected in a set)

    Raises
    ------
    ValueError
        if ``max_features`` is negative or greater than ``features``
    """
    if max_features == features:
        # The only possible subspace is the one holding every feature
        set_length = 1
    else:
        # Floor division keeps the binomial coefficient an exact int; true
        # division converted it to float and raised OverflowError as soon
        # as the number of combinations exceeded the float range
        number = factorial(features) // (
            factorial(max_features) * factorial(features - max_features)
        )
        # Generate at most 5 combinations
        set_length = min(5, number)
    comb = set()
    # Sorting each sample makes equal subspaces collapse into one set entry
    while len(comb) < set_length:
        comb.add(
            tuple(sorted(random.sample(range(features), max_features)))
        )
    return list(comb)
@staticmethod
def _get_random_subspace(
dataset: np.array, labels: np.array, max_features: int
) -> Tuple[int, ...]:
features = range(dataset.shape[1])
features_sets = list(combinations(features, max_features))
features_sets = Odte._generate_spaces(dataset.shape[1], max_features)
if len(features_sets) > 1:
index = random.randint(0, len(features_sets) - 1)
return features_sets[index]

View File

@ -1,7 +1,7 @@
# type: ignore
import unittest
import os
import random
import warnings
from sklearn.exceptions import ConvergenceWarning
@ -30,13 +30,13 @@ class Odte_test(unittest.TestCase):
def test_initialize_max_feature(self):
expected_values = [
[0, 5, 6, 15],
[0, 2, 3, 9, 11, 14],
[6, 7, 8, 15],
[3, 4, 5, 6, 10, 13],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
[0, 5, 6, 15],
[0, 5, 6, 15],
[0, 5, 6, 15],
[6, 7, 8, 15],
[6, 7, 8, 15],
[6, 7, 8, 15],
]
X, y = load_dataset(
random_state=self._random_state, n_features=16, n_samples=10
@ -49,6 +49,7 @@ class Odte_test(unittest.TestCase):
computed = tclf._get_random_subspace(X, y, tclf.max_features_)
expected = expected_values.pop(0)
self.assertListEqual(expected, list(computed))
# print(f"{list(computed)},")
def test_initialize_random(self):
expected = [37, 235, 908]
@ -128,11 +129,12 @@ class Odte_test(unittest.TestCase):
def test_score_splitter_max_features(self):
X, y = load_dataset(self._random_state, n_features=12, n_samples=150)
results = [
1.0,
1.0,
0.86,
0.8933333333333333,
0.9933333333333333,
0.9933333333333333,
]
random.seed(self._random_state)
for max_features in ["auto", None]:
for splitter in ["best", "random"]:
tclf = Odte(
@ -143,12 +145,22 @@ class Odte_test(unittest.TestCase):
tclf.set_params(
**dict(
base_estimator__splitter=splitter,
base_estimator__random_state=self._random_state,
)
)
expected = results.pop(0)
computed = tclf.fit(X, y).score(X, y)
# print(computed, splitter, max_features)
self.assertAlmostEqual(expected, computed)
def test_generate_subspaces(self):
    """Check how many subspaces Odte._generate_spaces returns.

    With 250 features every subspace size from 2 to 249 must yield five
    subspaces; for (3, 2) and (4, 3) the expected counts are 3 and 4.
    """
    n_features = 250
    for size in range(2, n_features):
        spaces = Odte._generate_spaces(n_features, size)
        self.assertEqual(5, len(spaces))
    # Small cases where fewer than five subspaces are expected
    small_cases = ((3, 3, 2), (4, 4, 3))
    for expected, total, size in small_cases:
        spaces = Odte._generate_spaces(total, size)
        self.assertEqual(expected, len(spaces))
@staticmethod
def test_is_a_sklearn_classifier():
os.environ["PYTHONWARNINGS"] = "ignore"