From ca773d3537864c7331cf332853621bdca7954bf2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Fri, 11 Dec 2020 11:29:45 +0100 Subject: [PATCH] Solve combinatorial explosion --- .pre-commit-config.yaml | 24 ++++++++++++------------ odte/Odte.py | 22 +++++++++++++++++++--- odte/tests/Odte_tests.py | 28 ++++++++++++++++++++-------- 3 files changed, 51 insertions(+), 23 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 97b54d1..cb9a435 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,19 +1,19 @@ repos: -- repo: https://github.com/ambv/black + - repo: https://github.com/ambv/black rev: stable hooks: - - id: black - language_version: python3.8 -- repo: https://gitlab.com/pycqa/flake8 + - id: black + language_version: python3.8 + - repo: https://gitlab.com/pycqa/flake8 rev: 3.8.3 hooks: - - id: flake8 -- repo: https://github.com/pre-commit/mirrors-mypy - rev: '' # Use the sha / tag you want to point at - hooks: - - id: mypy - args: [--strict, --ignore-missing-imports] -- repo: local + - id: flake8 + # - repo: https://github.com/pre-commit/mirrors-mypy + # rev: '' # Use the sha / tag you want to point at + # hooks: + # - id: mypy + # args: [--strict, --ignore-missing-imports] + - repo: local hooks: - id: unittest name: unittest @@ -24,4 +24,4 @@ repos: name: coverage entry: python -m coverage report -m --fail-under=100 language: system - pass_filenames: false \ No newline at end of file + pass_filenames: false diff --git a/odte/Odte.py b/odte/Odte.py index 0610d63..8e9ffe6 100644 --- a/odte/Odte.py +++ b/odte/Odte.py @@ -8,8 +8,8 @@ Build a forest of oblique trees based on STree from __future__ import annotations import random import sys +from math import factorial from typing import Union, Optional, Tuple, List -from itertools import combinations import numpy as np from sklearn.utils.multiclass import check_classification_targets from sklearn.base import clone, BaseEstimator, ClassifierMixin @@ -189,12 +189,28 @@ class Odte(BaseEnsemble, ClassifierMixin): # type: ignore ) return max_features + @staticmethod + def _generate_spaces(features: int, max_features: int) -> list: + comb = set() + # Generate at most 5 combinations + if max_features == features: + set_length = 1 + else: + number = factorial(features) / ( + factorial(max_features) * factorial(features - max_features) + ) + set_length = min(5, number) + while len(comb) < set_length: + comb.add( + tuple(sorted(random.sample(range(features), max_features))) + ) + return list(comb) + @staticmethod def _get_random_subspace( dataset: np.array, labels: np.array, max_features: int ) -> Tuple[int, ...]: - features = range(dataset.shape[1]) - features_sets = list(combinations(features, max_features)) + features_sets = Odte._generate_spaces(dataset.shape[1], max_features) if len(features_sets) > 1: index = random.randint(0, len(features_sets) - 1) return features_sets[index] diff --git a/odte/tests/Odte_tests.py b/odte/tests/Odte_tests.py index b02e8a1..e1c7a3e 100644 --- a/odte/tests/Odte_tests.py +++ b/odte/tests/Odte_tests.py @@ -1,7 +1,7 @@ # type: ignore import unittest import os - +import random import warnings from sklearn.exceptions import ConvergenceWarning @@ -30,13 +30,13 @@ class Odte_test(unittest.TestCase): def test_initialize_max_feature(self): expected_values = [ - [0, 5, 6, 15], - [0, 2, 3, 9, 11, 14], + [6, 7, 8, 15], + [3, 4, 5, 6, 10, 13], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], - [0, 5, 6, 15], - [0, 5, 6, 15], - [0, 5, 6, 15], + [6, 7, 8, 15], + [6, 7, 8, 15], + [6, 7, 8, 15], ] X, y = load_dataset( random_state=self._random_state, n_features=16, n_samples=10 @@ -49,6 +49,7 @@ class Odte_test(unittest.TestCase): computed = tclf._get_random_subspace(X, y, tclf.max_features_) expected = expected_values.pop(0) self.assertListEqual(expected, list(computed)) + # print(f"{list(computed)},") def test_initialize_random(self): expected = [37, 235, 908] @@ -128,11 +129,12 @@ class Odte_test(unittest.TestCase): def test_score_splitter_max_features(self): X, y = load_dataset(self._random_state, n_features=12, n_samples=150) results = [ - 1.0, - 1.0, + 0.86, + 0.8933333333333333, 0.9933333333333333, 0.9933333333333333, ] + random.seed(self._random_state) for max_features in ["auto", None]: for splitter in ["best", "random"]: tclf = Odte( @@ -143,12 +145,22 @@ class Odte_test(unittest.TestCase): tclf.set_params( **dict( base_estimator__splitter=splitter, + base_estimator__random_state=self._random_state, ) ) expected = results.pop(0) computed = tclf.fit(X, y).score(X, y) + # print(computed, splitter, max_features) self.assertAlmostEqual(expected, computed) + def test_generate_subspaces(self): + features = 250 + for max_features in range(2, features): + num = len(Odte._generate_spaces(features, max_features)) + self.assertEqual(5, num) + self.assertEqual(3, len(Odte._generate_spaces(3, 2))) + self.assertEqual(4, len(Odte._generate_spaces(4, 3))) + @staticmethod def test_is_a_sklearn_classifier(): os.environ["PYTHONWARNINGS"] = "ignore"