From 877c24f3f4e528c419f7ddbb706ea91a87ce7ee2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Fri, 25 Feb 2022 19:24:44 +0100 Subject: [PATCH] fix rc1 --- odte/Odte.py | 71 +++++++++++++++++++++++++--------------------------- 1 file changed, 34 insertions(+), 37 deletions(-) diff --git a/odte/Odte.py b/odte/Odte.py index a396da7..8c9c059 100644 --- a/odte/Odte.py +++ b/odte/Odte.py @@ -26,35 +26,6 @@ from stree import Stree # type: ignore from ._version import __version__ -def _parallel_build_tree( - base_estimator_: Stree, - X: np.ndarray, - y: np.ndarray, - weights: np.ndarray, - random_seed: int, - boot_samples: int, - max_features: int, - hyperparams: str, -) -> Tuple[BaseEstimator, Tuple[int, ...]]: - clf = base_estimator_ - hyperparams_ = json.loads(hyperparams) - hyperparams_.update(dict(random_state=random_seed)) - clf.set_params(**hyperparams_) - n_samples = X.shape[0] - # bootstrap - random_box = check_random_state(random_seed) - indices = random_box.randint(0, n_samples, boot_samples) - # update weights with the chosen samples - weights_update = np.bincount(indices, minlength=n_samples) - current_weights = weights * weights_update - # random subspace - features = Odte._get_random_subspace(X, y, max_features) - # train the classifier - bootstrap = X[indices, :] - clf.fit(bootstrap[:, features], y[indices], current_weights[indices]) - return (clf, features) - - class Odte(BaseEnsemble, ClassifierMixin): def __init__( self, @@ -135,15 +106,12 @@ class Odte(BaseEnsemble, ClassifierMixin): def _train( self, X: np.ndarray, y: np.ndarray, weights: np.ndarray ) -> Tuple[List[BaseEstimator], List[Tuple[int, ...]]]: - # np.random.RandomState(seed) n_samples = X.shape[0] boot_samples = self._get_bootstrap_n_samples(n_samples) - estimator = [] - for i in range(self.n_estimators): - estimator.append(clone(self.base_estimator_)) + estimator = clone(self.base_estimator_) return Parallel(n_jobs=self.n_jobs, prefer="threads")( # type: ignore - delayed(_parallel_build_tree)( - estimator[i], + delayed(Odte._parallel_build_tree)( + estimator, X, y, weights, @@ -152,11 +120,40 @@ class Odte(BaseEnsemble, ClassifierMixin): self.max_features_, self.be_hyperparams, ) - for i, random_seed in enumerate( - range(self.random_state, self.random_state + self.n_estimators) + for random_seed in range( + self.random_state, self.random_state + self.n_estimators ) ) + @staticmethod + def _parallel_build_tree( + base_estimator_: BaseEstimator, + X: np.ndarray, + y: np.ndarray, + weights: np.ndarray, + random_seed: int, + boot_samples: int, + max_features: int, + hyperparams: str, + ) -> Tuple[BaseEstimator, Tuple[int, ...]]: + clf = clone(base_estimator_) + hyperparams_ = json.loads(hyperparams) + hyperparams_.update(dict(random_state=random_seed)) + clf.set_params(**hyperparams_) + n_samples = X.shape[0] + # bootstrap + random_box = check_random_state(random_seed) + indices = random_box.randint(0, n_samples, boot_samples) + # update weights with the chosen samples + weights_update = np.bincount(indices, minlength=n_samples) + current_weights = weights * weights_update + # random subspace + features = Odte._get_random_subspace(X, y, max_features) + # train the classifier + bootstrap = X[indices, :] + clf.fit(bootstrap[:, features], y[indices], current_weights[indices]) + return (clf, features) + def _get_bootstrap_n_samples(self, n_samples: int) -> int: if self.max_samples is None: return n_samples