Fix python random init

This commit is contained in:
Ricardo Montañana Gómez 2022-03-10 13:17:56 +01:00
parent 98cadc7eeb
commit e01ca43cf9
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
2 changed files with 9 additions and 4 deletions

View File

@@ -30,7 +30,7 @@ class Odte(BaseEnsemble, ClassifierMixin):
     def __init__(
         self,
         # n_jobs = -1 to use all available cores
-        n_jobs: int = 1,
+        n_jobs: int = -1,
         base_estimator: BaseEstimator = None,
         random_state: int = 0,
         max_features: Optional[Union[str, int, float]] = None,
@@ -141,8 +141,10 @@ class Odte(BaseEnsemble, ClassifierMixin):
         hyperparams_.update(dict(random_state=random_seed))
         clf.set_params(**hyperparams_)
         n_samples = X.shape[0]
-        # bootstrap
+        # initialize random boxes
+        random.seed(random_seed)
         random_box = check_random_state(random_seed)
+        # bootstrap
         indices = random_box.randint(0, n_samples, boot_samples)
         # update weights with the chosen samples
         weights_update = np.bincount(indices, minlength=n_samples)

View File

@@ -46,7 +46,9 @@ class Odte_test(unittest.TestCase):
         )
         for max_features in [4, 0.4, 1.0, None, "auto", "sqrt", "log2"]:
             tclf = Odte(
-                random_state=self._random_state, max_features=max_features
+                random_state=self._random_state,
+                max_features=max_features,
+                n_jobs=1,
             )
             tclf.fit(X, y)
             computed = tclf._get_random_subspace(X, y, tclf.max_features_)
@@ -149,6 +151,7 @@ class Odte_test(unittest.TestCase):
                     base_estimator=Stree(),
                     random_state=self._random_state,
                     n_estimators=3,
+                    n_jobs=1,
                 )
                 tclf.set_params(
                     **dict(
@@ -160,7 +163,7 @@ class Odte_test(unittest.TestCase):
                 expected = results.pop(0)
                 computed = tclf.fit(X, y).score(X, y)
                 # print(computed, splitter, max_features)
-                self.assertAlmostEqual(expected, computed)
+                self.assertAlmostEqual(expected, computed, msg=splitter)
 
     def test_generate_subspaces(self):
         features = 250