Merge pull request #8 from Doctorado-ML/fix_python_random_init

Fix python random init
This commit is contained in:
Ricardo Montañana Gómez 2022-04-29 10:22:33 +02:00 committed by GitHub
commit 7300bd66db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 15 additions and 12 deletions

View File

@ -1,23 +1,23 @@
repos: repos:
- repo: https://github.com/ambv/black - repo: https://github.com/ambv/black
rev: 20.8b1 rev: 22.3.0
hooks: hooks:
- id: black - id: black
exclude: ".virtual_documents" exclude: ".virtual_documents"
language_version: python3.9 language_version: python3.9
- repo: https://gitlab.com/pycqa/flake8 - repo: https://gitlab.com/pycqa/flake8
rev: 3.8.4 rev: 3.9.2
hooks: hooks:
- id: flake8 - id: flake8
exclude: ".virtual_documents" exclude: ".virtual_documents"
- repo: https://github.com/pre-commit/mirrors-mypy - repo: https://github.com/pre-commit/mirrors-mypy
rev: "v0.790" # Use the sha / tag you want to point at rev: "v0.942" # Use the sha / tag you want to point at
hooks: hooks:
- id: mypy - id: mypy
#args: [--strict, --ignore-missing-imports] #args: [--strict, --ignore-missing-imports]
exclude: odte/tests exclude: odte/tests
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.4.0 rev: v4.2.0
hooks: hooks:
- id: trailing-whitespace - id: trailing-whitespace
- id: check-case-conflict - id: check-case-conflict

View File

@ -15,7 +15,7 @@ from sklearn.utils.multiclass import ( # type: ignore
check_classification_targets, check_classification_targets,
) )
from sklearn.base import clone, BaseEstimator, ClassifierMixin # type: ignore from sklearn.base import clone, BaseEstimator, ClassifierMixin # type: ignore
from sklearn.utils import check_random_state from sklearn.utils import check_random_state # type: ignore
from sklearn.ensemble import BaseEnsemble # type: ignore from sklearn.ensemble import BaseEnsemble # type: ignore
from sklearn.utils.validation import ( # type: ignore from sklearn.utils.validation import ( # type: ignore
check_is_fitted, check_is_fitted,
@ -30,7 +30,7 @@ class Odte(BaseEnsemble, ClassifierMixin):
def __init__( def __init__(
self, self,
# n_jobs = -1 to use all available cores # n_jobs = -1 to use all available cores
n_jobs: int = 1, n_jobs: int = -1,
base_estimator: BaseEstimator = None, base_estimator: BaseEstimator = None,
random_state: int = 0, random_state: int = 0,
max_features: Optional[Union[str, int, float]] = None, max_features: Optional[Union[str, int, float]] = None,
@ -141,8 +141,10 @@ class Odte(BaseEnsemble, ClassifierMixin):
hyperparams_.update(dict(random_state=random_seed)) hyperparams_.update(dict(random_state=random_seed))
clf.set_params(**hyperparams_) clf.set_params(**hyperparams_)
n_samples = X.shape[0] n_samples = X.shape[0]
# bootstrap # initialize random boxes
random.seed(random_seed)
random_box = check_random_state(random_seed) random_box = check_random_state(random_seed)
# bootstrap
indices = random_box.randint(0, n_samples, boot_samples) indices = random_box.randint(0, n_samples, boot_samples)
# update weights with the chosen samples # update weights with the chosen samples
weights_update = np.bincount(indices, minlength=n_samples) weights_update = np.bincount(indices, minlength=n_samples)

View File

@ -1 +1 @@
__version__ = "0.3.2" __version__ = "0.3.3"

View File

@ -1,7 +1,6 @@
# type: ignore # type: ignore
import unittest import unittest
import os import os
import random
import warnings import warnings
import json import json
from sklearn.exceptions import ConvergenceWarning, NotFittedError from sklearn.exceptions import ConvergenceWarning, NotFittedError
@ -46,7 +45,9 @@ class Odte_test(unittest.TestCase):
) )
for max_features in [4, 0.4, 1.0, None, "auto", "sqrt", "log2"]: for max_features in [4, 0.4, 1.0, None, "auto", "sqrt", "log2"]:
tclf = Odte( tclf = Odte(
random_state=self._random_state, max_features=max_features random_state=self._random_state,
max_features=max_features,
n_jobs=1,
) )
tclf.fit(X, y) tclf.fit(X, y)
computed = tclf._get_random_subspace(X, y, tclf.max_features_) computed = tclf._get_random_subspace(X, y, tclf.max_features_)
@ -135,7 +136,6 @@ class Odte_test(unittest.TestCase):
0.97, # iwss None 0.97, # iwss None
0.97, # cfs None 0.97, # cfs None
] ]
random.seed(self._random_state)
for max_features in ["auto", None]: for max_features in ["auto", None]:
for splitter in [ for splitter in [
"best", "best",
@ -149,6 +149,7 @@ class Odte_test(unittest.TestCase):
base_estimator=Stree(), base_estimator=Stree(),
random_state=self._random_state, random_state=self._random_state,
n_estimators=3, n_estimators=3,
n_jobs=1,
) )
tclf.set_params( tclf.set_params(
**dict( **dict(
@ -160,7 +161,7 @@ class Odte_test(unittest.TestCase):
expected = results.pop(0) expected = results.pop(0)
computed = tclf.fit(X, y).score(X, y) computed = tclf.fit(X, y).score(X, y)
# print(computed, splitter, max_features) # print(computed, splitter, max_features)
self.assertAlmostEqual(expected, computed) self.assertAlmostEqual(expected, computed, msg=splitter)
def test_generate_subspaces(self): def test_generate_subspaces(self):
features = 250 features = 250