Fix python random init

This commit is contained in:
Ricardo Montañana Gómez 2022-03-10 13:17:56 +01:00
parent 98cadc7eeb
commit e01ca43cf9
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
2 changed files with 9 additions and 4 deletions

View File

@ -30,7 +30,7 @@ class Odte(BaseEnsemble, ClassifierMixin):
def __init__(
self,
# n_jobs = -1 to use all available cores
n_jobs: int = 1,
n_jobs: int = -1,
base_estimator: BaseEstimator = None,
random_state: int = 0,
max_features: Optional[Union[str, int, float]] = None,
@ -141,8 +141,10 @@ class Odte(BaseEnsemble, ClassifierMixin):
hyperparams_.update(dict(random_state=random_seed))
clf.set_params(**hyperparams_)
n_samples = X.shape[0]
# bootstrap
# initialize random boxes
random.seed(random_seed)
random_box = check_random_state(random_seed)
# bootstrap
indices = random_box.randint(0, n_samples, boot_samples)
# update weights with the chosen samples
weights_update = np.bincount(indices, minlength=n_samples)

View File

@ -46,7 +46,9 @@ class Odte_test(unittest.TestCase):
)
for max_features in [4, 0.4, 1.0, None, "auto", "sqrt", "log2"]:
tclf = Odte(
random_state=self._random_state, max_features=max_features
random_state=self._random_state,
max_features=max_features,
n_jobs=1,
)
tclf.fit(X, y)
computed = tclf._get_random_subspace(X, y, tclf.max_features_)
@ -149,6 +151,7 @@ class Odte_test(unittest.TestCase):
base_estimator=Stree(),
random_state=self._random_state,
n_estimators=3,
n_jobs=1,
)
tclf.set_params(
**dict(
@ -160,7 +163,7 @@ class Odte_test(unittest.TestCase):
expected = results.pop(0)
computed = tclf.fit(X, y).score(X, y)
# print(computed, splitter, max_features)
self.assertAlmostEqual(expected, computed)
self.assertAlmostEqual(expected, computed, msg=splitter)
def test_generate_subspaces(self):
features = 250