From e01ca43cf9af1b4b48cef4e8f4c0555e72715ad5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Thu, 10 Mar 2022 13:17:56 +0100 Subject: [PATCH 1/3] Fix python random init --- odte/Odte.py | 6 ++++-- odte/tests/Odte_tests.py | 7 +++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/odte/Odte.py b/odte/Odte.py index 8c9c059..1425aba 100644 --- a/odte/Odte.py +++ b/odte/Odte.py @@ -30,7 +30,7 @@ class Odte(BaseEnsemble, ClassifierMixin): def __init__( self, # n_jobs = -1 to use all available cores - n_jobs: int = 1, + n_jobs: int = -1, base_estimator: BaseEstimator = None, random_state: int = 0, max_features: Optional[Union[str, int, float]] = None, @@ -141,8 +141,10 @@ class Odte(BaseEnsemble, ClassifierMixin): hyperparams_.update(dict(random_state=random_seed)) clf.set_params(**hyperparams_) n_samples = X.shape[0] - # bootstrap + # initialize random boxes + random.seed(random_seed) random_box = check_random_state(random_seed) + # bootstrap indices = random_box.randint(0, n_samples, boot_samples) # update weights with the chosen samples weights_update = np.bincount(indices, minlength=n_samples) diff --git a/odte/tests/Odte_tests.py b/odte/tests/Odte_tests.py index 01ca3bc..0d1aa86 100644 --- a/odte/tests/Odte_tests.py +++ b/odte/tests/Odte_tests.py @@ -46,7 +46,9 @@ class Odte_test(unittest.TestCase): ) for max_features in [4, 0.4, 1.0, None, "auto", "sqrt", "log2"]: tclf = Odte( - random_state=self._random_state, max_features=max_features + random_state=self._random_state, + max_features=max_features, + n_jobs=1, ) tclf.fit(X, y) computed = tclf._get_random_subspace(X, y, tclf.max_features_) @@ -149,6 +151,7 @@ class Odte_test(unittest.TestCase): base_estimator=Stree(), random_state=self._random_state, n_estimators=3, + n_jobs=1, ) tclf.set_params( **dict( @@ -160,7 +163,7 @@ class Odte_test(unittest.TestCase): expected = results.pop(0) computed = tclf.fit(X, y).score(X, y) # print(computed, splitter, max_features) - self.assertAlmostEqual(expected, computed) + self.assertAlmostEqual(expected, computed, msg=splitter) def test_generate_subspaces(self): features = 250 From 267a17a7084fa3c6df07142a0e40111f8481bc8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Wed, 20 Apr 2022 11:25:45 +0200 Subject: [PATCH 2/3] Remove unneeded Random module from tests Update pre-commit config --- .pre-commit-config.yaml | 8 ++++---- odte/tests/Odte_tests.py | 2 -- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e829189..9497593 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,23 +1,23 @@ repos: - repo: https://github.com/ambv/black - rev: 20.8b1 + rev: 22.3.0 hooks: - id: black exclude: ".virtual_documents" language_version: python3.9 - repo: https://gitlab.com/pycqa/flake8 - rev: 3.8.4 + rev: 3.9.2 hooks: - id: flake8 exclude: ".virtual_documents" - repo: https://github.com/pre-commit/mirrors-mypy - rev: "v0.790" # Use the sha / tag you want to point at + rev: "v0.942" # Use the sha / tag you want to point at hooks: - id: mypy #args: [--strict, --ignore-missing-imports] exclude: odte/tests - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v3.4.0 + rev: v4.2.0 hooks: - id: trailing-whitespace - id: check-case-conflict diff --git a/odte/tests/Odte_tests.py b/odte/tests/Odte_tests.py index 0d1aa86..3974b45 100644 --- a/odte/tests/Odte_tests.py +++ b/odte/tests/Odte_tests.py @@ -1,7 +1,6 @@ # type: ignore import unittest import os -import random import warnings import json from sklearn.exceptions import ConvergenceWarning, NotFittedError @@ -137,7 +136,6 @@ class Odte_test(unittest.TestCase): 0.97, # iwss None 0.97, # cfs None ] - random.seed(self._random_state) for max_features in ["auto", None]: for splitter in [ "best", From 114f53d5e8c525a4fc9b86db9ef3048127e00661 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Fri, 29 Apr 2022 10:07:05 +0200 Subject: [PATCH 3/3] Update version file --- odte/Odte.py | 2 +- odte/_version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/odte/Odte.py b/odte/Odte.py index 1425aba..825b6c3 100644 --- a/odte/Odte.py +++ b/odte/Odte.py @@ -15,7 +15,7 @@ from sklearn.utils.multiclass import ( # type: ignore check_classification_targets, ) from sklearn.base import clone, BaseEstimator, ClassifierMixin # type: ignore -from sklearn.utils import check_random_state +from sklearn.utils import check_random_state # type: ignore from sklearn.ensemble import BaseEnsemble # type: ignore from sklearn.utils.validation import ( # type: ignore check_is_fitted, diff --git a/odte/_version.py b/odte/_version.py index f9aa3e1..e19434e 100644 --- a/odte/_version.py +++ b/odte/_version.py @@ -1 +1 @@ -__version__ = "0.3.2" +__version__ = "0.3.3"