mirror of
https://github.com/Doctorado-ML/Odte.git
synced 2025-07-11 08:12:06 +00:00
Solve combinatorial explosion
This commit is contained in:
parent
7a49d672df
commit
ca773d3537
@ -1,19 +1,19 @@
|
|||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/ambv/black
|
- repo: https://github.com/ambv/black
|
||||||
rev: stable
|
rev: stable
|
||||||
hooks:
|
hooks:
|
||||||
- id: black
|
- id: black
|
||||||
language_version: python3.8
|
language_version: python3.8
|
||||||
- repo: https://gitlab.com/pycqa/flake8
|
- repo: https://gitlab.com/pycqa/flake8
|
||||||
rev: 3.8.3
|
rev: 3.8.3
|
||||||
hooks:
|
hooks:
|
||||||
- id: flake8
|
- id: flake8
|
||||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
# - repo: https://github.com/pre-commit/mirrors-mypy
|
||||||
rev: '' # Use the sha / tag you want to point at
|
# rev: '' # Use the sha / tag you want to point at
|
||||||
hooks:
|
# hooks:
|
||||||
- id: mypy
|
# - id: mypy
|
||||||
args: [--strict, --ignore-missing-imports]
|
# args: [--strict, --ignore-missing-imports]
|
||||||
- repo: local
|
- repo: local
|
||||||
hooks:
|
hooks:
|
||||||
- id: unittest
|
- id: unittest
|
||||||
name: unittest
|
name: unittest
|
||||||
|
22
odte/Odte.py
22
odte/Odte.py
@ -8,8 +8,8 @@ Build a forest of oblique trees based on STree
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import random
|
import random
|
||||||
import sys
|
import sys
|
||||||
|
from math import factorial
|
||||||
from typing import Union, Optional, Tuple, List
|
from typing import Union, Optional, Tuple, List
|
||||||
from itertools import combinations
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sklearn.utils.multiclass import check_classification_targets
|
from sklearn.utils.multiclass import check_classification_targets
|
||||||
from sklearn.base import clone, BaseEstimator, ClassifierMixin
|
from sklearn.base import clone, BaseEstimator, ClassifierMixin
|
||||||
@ -189,12 +189,28 @@ class Odte(BaseEnsemble, ClassifierMixin): # type: ignore
|
|||||||
)
|
)
|
||||||
return max_features
|
return max_features
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _generate_spaces(features: int, max_features: int) -> list:
|
||||||
|
comb = set()
|
||||||
|
# Generate at most 5 combinations
|
||||||
|
if max_features == features:
|
||||||
|
set_length = 1
|
||||||
|
else:
|
||||||
|
number = factorial(features) / (
|
||||||
|
factorial(max_features) * factorial(features - max_features)
|
||||||
|
)
|
||||||
|
set_length = min(5, number)
|
||||||
|
while len(comb) < set_length:
|
||||||
|
comb.add(
|
||||||
|
tuple(sorted(random.sample(range(features), max_features)))
|
||||||
|
)
|
||||||
|
return list(comb)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_random_subspace(
|
def _get_random_subspace(
|
||||||
dataset: np.array, labels: np.array, max_features: int
|
dataset: np.array, labels: np.array, max_features: int
|
||||||
) -> Tuple[int, ...]:
|
) -> Tuple[int, ...]:
|
||||||
features = range(dataset.shape[1])
|
features_sets = Odte._generate_spaces(dataset.shape[1], max_features)
|
||||||
features_sets = list(combinations(features, max_features))
|
|
||||||
if len(features_sets) > 1:
|
if len(features_sets) > 1:
|
||||||
index = random.randint(0, len(features_sets) - 1)
|
index = random.randint(0, len(features_sets) - 1)
|
||||||
return features_sets[index]
|
return features_sets[index]
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# type: ignore
|
# type: ignore
|
||||||
import unittest
|
import unittest
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
import warnings
|
import warnings
|
||||||
from sklearn.exceptions import ConvergenceWarning
|
from sklearn.exceptions import ConvergenceWarning
|
||||||
|
|
||||||
@ -30,13 +30,13 @@ class Odte_test(unittest.TestCase):
|
|||||||
|
|
||||||
def test_initialize_max_feature(self):
|
def test_initialize_max_feature(self):
|
||||||
expected_values = [
|
expected_values = [
|
||||||
[0, 5, 6, 15],
|
[6, 7, 8, 15],
|
||||||
[0, 2, 3, 9, 11, 14],
|
[3, 4, 5, 6, 10, 13],
|
||||||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
|
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
|
||||||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
|
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
|
||||||
[0, 5, 6, 15],
|
[6, 7, 8, 15],
|
||||||
[0, 5, 6, 15],
|
[6, 7, 8, 15],
|
||||||
[0, 5, 6, 15],
|
[6, 7, 8, 15],
|
||||||
]
|
]
|
||||||
X, y = load_dataset(
|
X, y = load_dataset(
|
||||||
random_state=self._random_state, n_features=16, n_samples=10
|
random_state=self._random_state, n_features=16, n_samples=10
|
||||||
@ -49,6 +49,7 @@ class Odte_test(unittest.TestCase):
|
|||||||
computed = tclf._get_random_subspace(X, y, tclf.max_features_)
|
computed = tclf._get_random_subspace(X, y, tclf.max_features_)
|
||||||
expected = expected_values.pop(0)
|
expected = expected_values.pop(0)
|
||||||
self.assertListEqual(expected, list(computed))
|
self.assertListEqual(expected, list(computed))
|
||||||
|
# print(f"{list(computed)},")
|
||||||
|
|
||||||
def test_initialize_random(self):
|
def test_initialize_random(self):
|
||||||
expected = [37, 235, 908]
|
expected = [37, 235, 908]
|
||||||
@ -128,11 +129,12 @@ class Odte_test(unittest.TestCase):
|
|||||||
def test_score_splitter_max_features(self):
|
def test_score_splitter_max_features(self):
|
||||||
X, y = load_dataset(self._random_state, n_features=12, n_samples=150)
|
X, y = load_dataset(self._random_state, n_features=12, n_samples=150)
|
||||||
results = [
|
results = [
|
||||||
1.0,
|
0.86,
|
||||||
1.0,
|
0.8933333333333333,
|
||||||
0.9933333333333333,
|
0.9933333333333333,
|
||||||
0.9933333333333333,
|
0.9933333333333333,
|
||||||
]
|
]
|
||||||
|
random.seed(self._random_state)
|
||||||
for max_features in ["auto", None]:
|
for max_features in ["auto", None]:
|
||||||
for splitter in ["best", "random"]:
|
for splitter in ["best", "random"]:
|
||||||
tclf = Odte(
|
tclf = Odte(
|
||||||
@ -143,12 +145,22 @@ class Odte_test(unittest.TestCase):
|
|||||||
tclf.set_params(
|
tclf.set_params(
|
||||||
**dict(
|
**dict(
|
||||||
base_estimator__splitter=splitter,
|
base_estimator__splitter=splitter,
|
||||||
|
base_estimator__random_state=self._random_state,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
expected = results.pop(0)
|
expected = results.pop(0)
|
||||||
computed = tclf.fit(X, y).score(X, y)
|
computed = tclf.fit(X, y).score(X, y)
|
||||||
|
# print(computed, splitter, max_features)
|
||||||
self.assertAlmostEqual(expected, computed)
|
self.assertAlmostEqual(expected, computed)
|
||||||
|
|
||||||
|
def test_generate_subspaces(self):
|
||||||
|
features = 250
|
||||||
|
for max_features in range(2, features):
|
||||||
|
num = len(Odte._generate_spaces(features, max_features))
|
||||||
|
self.assertEqual(5, num)
|
||||||
|
self.assertEqual(3, len(Odte._generate_spaces(3, 2)))
|
||||||
|
self.assertEqual(4, len(Odte._generate_spaces(4, 3)))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def test_is_a_sklearn_classifier():
|
def test_is_a_sklearn_classifier():
|
||||||
os.environ["PYTHONWARNINGS"] = "ignore"
|
os.environ["PYTHONWARNINGS"] = "ignore"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user