mirror of
https://github.com/Doctorado-ML/Odte.git
synced 2025-07-11 16:22:00 +00:00
Solve combinatorial explosion
This commit is contained in:
parent
7a49d672df
commit
ca773d3537
@ -1,19 +1,19 @@
|
||||
repos:
|
||||
- repo: https://github.com/ambv/black
|
||||
- repo: https://github.com/ambv/black
|
||||
rev: stable
|
||||
hooks:
|
||||
- id: black
|
||||
language_version: python3.8
|
||||
- repo: https://gitlab.com/pycqa/flake8
|
||||
- id: black
|
||||
language_version: python3.8
|
||||
- repo: https://gitlab.com/pycqa/flake8
|
||||
rev: 3.8.3
|
||||
hooks:
|
||||
- id: flake8
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: '' # Use the sha / tag you want to point at
|
||||
hooks:
|
||||
- id: mypy
|
||||
args: [--strict, --ignore-missing-imports]
|
||||
- repo: local
|
||||
- id: flake8
|
||||
# - repo: https://github.com/pre-commit/mirrors-mypy
|
||||
# rev: '' # Use the sha / tag you want to point at
|
||||
# hooks:
|
||||
# - id: mypy
|
||||
# args: [--strict, --ignore-missing-imports]
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: unittest
|
||||
name: unittest
|
||||
|
22
odte/Odte.py
22
odte/Odte.py
@ -8,8 +8,8 @@ Build a forest of oblique trees based on STree
|
||||
from __future__ import annotations
|
||||
import random
|
||||
import sys
|
||||
from math import factorial
|
||||
from typing import Union, Optional, Tuple, List
|
||||
from itertools import combinations
|
||||
import numpy as np
|
||||
from sklearn.utils.multiclass import check_classification_targets
|
||||
from sklearn.base import clone, BaseEstimator, ClassifierMixin
|
||||
@ -189,12 +189,28 @@ class Odte(BaseEnsemble, ClassifierMixin): # type: ignore
|
||||
)
|
||||
return max_features
|
||||
|
||||
@staticmethod
def _generate_spaces(features: int, max_features: int) -> list:
    """Draw a few distinct random feature subspaces.

    Instead of enumerating every combination of ``max_features`` columns
    out of ``features`` (which explodes combinatorially), sample random
    subsets until at most 5 different ones have been collected.

    Parameters
    ----------
    features : int
        Total number of features; indices are ``0 .. features - 1``.
    max_features : int
        Number of feature indices in each subspace.

    Returns
    -------
    list
        Distinct subspaces, each one a tuple of feature indices sorted in
        ascending order. Its length is the smaller of 5 and the number of
        possible combinations. The order of the list is unspecified.

    Raises
    ------
    ValueError
        If ``max_features`` is negative or greater than ``features``.
    """
    if not 0 <= max_features <= features:
        raise ValueError(
            "max_features must be between 0 and the number of features"
        )
    # Generate at most 5 combinations
    if max_features == features:
        # Only one subspace exists: the one holding every feature
        set_length = 1
    else:
        # Exact integer arithmetic: true division would build a float and
        # overflow as soon as the number of combinations exceeds ~1.8e308
        number = factorial(features) // (
            factorial(max_features) * factorial(features - max_features)
        )
        set_length = min(5, number)
    comb = set()
    # The set silently drops repeated draws, so keep sampling until enough
    # different subspaces have been collected
    while len(comb) < set_length:
        comb.add(
            tuple(sorted(random.sample(range(features), max_features)))
        )
    return list(comb)
|
||||
|
||||
@staticmethod
|
||||
def _get_random_subspace(
|
||||
dataset: np.array, labels: np.array, max_features: int
|
||||
) -> Tuple[int, ...]:
|
||||
features = range(dataset.shape[1])
|
||||
features_sets = list(combinations(features, max_features))
|
||||
features_sets = Odte._generate_spaces(dataset.shape[1], max_features)
|
||||
if len(features_sets) > 1:
|
||||
index = random.randint(0, len(features_sets) - 1)
|
||||
return features_sets[index]
|
||||
|
@ -1,7 +1,7 @@
|
||||
# type: ignore
|
||||
import unittest
|
||||
import os
|
||||
|
||||
import random
|
||||
import warnings
|
||||
from sklearn.exceptions import ConvergenceWarning
|
||||
|
||||
@ -30,13 +30,13 @@ class Odte_test(unittest.TestCase):
|
||||
|
||||
def test_initialize_max_feature(self):
|
||||
expected_values = [
|
||||
[0, 5, 6, 15],
|
||||
[0, 2, 3, 9, 11, 14],
|
||||
[6, 7, 8, 15],
|
||||
[3, 4, 5, 6, 10, 13],
|
||||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
|
||||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
|
||||
[0, 5, 6, 15],
|
||||
[0, 5, 6, 15],
|
||||
[0, 5, 6, 15],
|
||||
[6, 7, 8, 15],
|
||||
[6, 7, 8, 15],
|
||||
[6, 7, 8, 15],
|
||||
]
|
||||
X, y = load_dataset(
|
||||
random_state=self._random_state, n_features=16, n_samples=10
|
||||
@ -49,6 +49,7 @@ class Odte_test(unittest.TestCase):
|
||||
computed = tclf._get_random_subspace(X, y, tclf.max_features_)
|
||||
expected = expected_values.pop(0)
|
||||
self.assertListEqual(expected, list(computed))
|
||||
# print(f"{list(computed)},")
|
||||
|
||||
def test_initialize_random(self):
|
||||
expected = [37, 235, 908]
|
||||
@ -128,11 +129,12 @@ class Odte_test(unittest.TestCase):
|
||||
def test_score_splitter_max_features(self):
|
||||
X, y = load_dataset(self._random_state, n_features=12, n_samples=150)
|
||||
results = [
|
||||
1.0,
|
||||
1.0,
|
||||
0.86,
|
||||
0.8933333333333333,
|
||||
0.9933333333333333,
|
||||
0.9933333333333333,
|
||||
]
|
||||
random.seed(self._random_state)
|
||||
for max_features in ["auto", None]:
|
||||
for splitter in ["best", "random"]:
|
||||
tclf = Odte(
|
||||
@ -143,12 +145,22 @@ class Odte_test(unittest.TestCase):
|
||||
tclf.set_params(
|
||||
**dict(
|
||||
base_estimator__splitter=splitter,
|
||||
base_estimator__random_state=self._random_state,
|
||||
)
|
||||
)
|
||||
expected = results.pop(0)
|
||||
computed = tclf.fit(X, y).score(X, y)
|
||||
# print(computed, splitter, max_features)
|
||||
self.assertAlmostEqual(expected, computed)
|
||||
|
||||
def test_generate_subspaces(self):
    """Check how many subspaces Odte._generate_spaces returns.

    With 250 features every subset size from 2 to 249 must yield exactly
    5 subspaces; choosing 2 out of 3 and 3 out of 4 features must yield
    3 and 4 subspaces respectively.
    """
    total_features = 250
    for subset_size in range(2, total_features):
        spaces = Odte._generate_spaces(total_features, subset_size)
        self.assertEqual(5, len(spaces))
    three_choose_two = Odte._generate_spaces(3, 2)
    self.assertEqual(3, len(three_choose_two))
    four_choose_three = Odte._generate_spaces(4, 3)
    self.assertEqual(4, len(four_choose_three))
|
||||
|
||||
@staticmethod
|
||||
def test_is_a_sklearn_classifier():
|
||||
os.environ["PYTHONWARNINGS"] = "ignore"
|
||||
|
Loading…
x
Reference in New Issue
Block a user