Mirror of https://github.com/Doctorado-ML/Odte.git, synced 2025-07-11 16:22:00 +00:00
Merge pull request #8 from Doctorado-ML/fix_python_random_init
Fix python random init
Commit 7300bd66db
.pre-commit-config.yaml
@@ -1,23 +1,23 @@
 repos:
   - repo: https://github.com/ambv/black
-    rev: 20.8b1
+    rev: 22.3.0
     hooks:
       - id: black
         exclude: ".virtual_documents"
         language_version: python3.9
   - repo: https://gitlab.com/pycqa/flake8
-    rev: 3.8.4
+    rev: 3.9.2
     hooks:
       - id: flake8
         exclude: ".virtual_documents"
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: "v0.790" # Use the sha / tag you want to point at
+    rev: "v0.942" # Use the sha / tag you want to point at
     hooks:
       - id: mypy
         #args: [--strict, --ignore-missing-imports]
         exclude: odte/tests
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v3.4.0
+    rev: v4.2.0
     hooks:
       - id: trailing-whitespace
       - id: check-case-conflict
@@ -15,7 +15,7 @@ from sklearn.utils.multiclass import (  # type: ignore
     check_classification_targets,
 )
 from sklearn.base import clone, BaseEstimator, ClassifierMixin  # type: ignore
-from sklearn.utils import check_random_state
+from sklearn.utils import check_random_state  # type: ignore
 from sklearn.ensemble import BaseEnsemble  # type: ignore
 from sklearn.utils.validation import (  # type: ignore
     check_is_fitted,
@@ -30,7 +30,7 @@ class Odte(BaseEnsemble, ClassifierMixin):
     def __init__(
         self,
         # n_jobs = -1 to use all available cores
-        n_jobs: int = 1,
+        n_jobs: int = -1,
         base_estimator: BaseEstimator = None,
         random_state: int = 0,
         max_features: Optional[Union[str, int, float]] = None,
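
The default for n_jobs moves from 1 to -1, so an Odte forest now trains on all available cores unless the caller overrides it. A minimal usage sketch, assuming the estimator is importable as odte.Odte (the import path does not appear in this diff) and using only constructor arguments visible here and in the updated tests:

from sklearn.datasets import load_iris
from odte import Odte  # assumed import path

X, y = load_iris(return_X_y=True)

# n_jobs now defaults to -1 (all cores); pass n_jobs=1 to keep the run
# single-threaded, as the updated tests do.
clf = Odte(random_state=0, max_features="auto", n_estimators=3, n_jobs=1)
clf.fit(X, y)
print(clf.score(X, y))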
@@ -141,8 +141,10 @@ class Odte(BaseEnsemble, ClassifierMixin):
         hyperparams_.update(dict(random_state=random_seed))
         clf.set_params(**hyperparams_)
         n_samples = X.shape[0]
-        # bootstrap
+        # initialize random boxes
+        random.seed(random_seed)
         random_box = check_random_state(random_seed)
+        # bootstrap
         indices = random_box.randint(0, n_samples, boot_samples)
         # update weights with the chosen samples
         weights_update = np.bincount(indices, minlength=n_samples)
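
This hunk is the heart of the fix: each tree's seed now initializes Python's stdlib random module in addition to the NumPy RandomState that check_random_state builds, so every draw tied to random_seed is reproducible. A standalone sketch of the pattern, reusing the names from the diff (random_seed, n_samples, boot_samples) with made-up values:

import random
import numpy as np
from sklearn.utils import check_random_state

random_seed = 17    # per-tree seed, hypothetical value
n_samples = 150     # rows in X
boot_samples = 150  # bootstrap sample size

# Seed the stdlib RNG and derive a numpy RandomState from the same seed.
random.seed(random_seed)
random_box = check_random_state(random_seed)  # returns np.random.RandomState

# Bootstrap: draw sample indices with replacement...
indices = random_box.randint(0, n_samples, boot_samples)
# ...and convert them to per-sample weights (how often each row was drawn).
weights_update = np.bincount(indices, minlength=n_samples)

assert weights_update.sum() == boot_samples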
@@ -1 +1 @@
-__version__ = "0.3.2"
+__version__ = "0.3.3"
@@ -1,7 +1,6 @@
 # type: ignore
 import unittest
 import os
-import random
 import warnings
 import json
 from sklearn.exceptions import ConvergenceWarning, NotFittedError
@@ -46,7 +45,9 @@ class Odte_test(unittest.TestCase):
         )
         for max_features in [4, 0.4, 1.0, None, "auto", "sqrt", "log2"]:
             tclf = Odte(
-                random_state=self._random_state, max_features=max_features
+                random_state=self._random_state,
+                max_features=max_features,
+                n_jobs=1,
             )
             tclf.fit(X, y)
             computed = tclf._get_random_subspace(X, y, tclf.max_features_)
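
The test sweeps integer, float, None, and string values of max_features and checks the subspaces drawn from the resolved tclf.max_features_. For orientation, this is the usual scikit-learn convention for turning such values into a feature count; it is only an illustration of that convention, not a copy of Odte's _get_random_subspace:

from math import sqrt, log2

def resolve_max_features(max_features, n_features: int) -> int:
    # Hypothetical helper following the common scikit-learn convention.
    if max_features is None:
        return n_features
    if max_features in ("auto", "sqrt"):
        return max(1, int(sqrt(n_features)))
    if max_features == "log2":
        return max(1, int(log2(n_features)))
    if isinstance(max_features, float):  # a fraction of the features, e.g. 0.4
        return max(1, int(max_features * n_features))
    return int(max_features)             # an absolute count, e.g. 4

print([resolve_max_features(m, 16) for m in [4, 0.4, 1.0, None, "auto", "sqrt", "log2"]])
# -> [4, 6, 16, 16, 4, 4, 4]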
@@ -135,7 +136,6 @@ class Odte_test(unittest.TestCase):
             0.97,  # iwss None
             0.97,  # cfs None
         ]
-        random.seed(self._random_state)
         for max_features in ["auto", None]:
             for splitter in [
                 "best",
@@ -149,6 +149,7 @@ class Odte_test(unittest.TestCase):
                     base_estimator=Stree(),
                     random_state=self._random_state,
                     n_estimators=3,
+                    n_jobs=1,
                 )
                 tclf.set_params(
                     **dict(
@@ -160,7 +161,7 @@ class Odte_test(unittest.TestCase):
                 expected = results.pop(0)
                 computed = tclf.fit(X, y).score(X, y)
                 # print(computed, splitter, max_features)
-                self.assertAlmostEqual(expected, computed)
+                self.assertAlmostEqual(expected, computed, msg=splitter)

     def test_generate_subspaces(self):
         features = 250
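
Passing msg=splitter to assertAlmostEqual makes a mismatch report which splitter was being scored when the expected value was popped, instead of a bare numeric difference. A tiny self-contained illustration of that unittest feature (values are made up):

import unittest

class ScoreMessages(unittest.TestCase):
    def test_scores(self):
        # On failure the assertion message names the offending splitter.
        for splitter, expected, computed in [("best", 0.97, 0.97)]:
            self.assertAlmostEqual(expected, computed, msg=splitter)

if __name__ == "__main__":
    unittest.main()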