This commit is contained in:
Ricardo Montañana Gómez 2022-02-25 19:24:44 +01:00
parent 9e5fe8c791
commit 877c24f3f4
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE

View File

@ -26,35 +26,6 @@ from stree import Stree # type: ignore
from ._version import __version__ from ._version import __version__
def _parallel_build_tree(
base_estimator_: Stree,
X: np.ndarray,
y: np.ndarray,
weights: np.ndarray,
random_seed: int,
boot_samples: int,
max_features: int,
hyperparams: str,
) -> Tuple[BaseEstimator, Tuple[int, ...]]:
clf = base_estimator_
hyperparams_ = json.loads(hyperparams)
hyperparams_.update(dict(random_state=random_seed))
clf.set_params(**hyperparams_)
n_samples = X.shape[0]
# bootstrap
random_box = check_random_state(random_seed)
indices = random_box.randint(0, n_samples, boot_samples)
# update weights with the chosen samples
weights_update = np.bincount(indices, minlength=n_samples)
current_weights = weights * weights_update
# random subspace
features = Odte._get_random_subspace(X, y, max_features)
# train the classifier
bootstrap = X[indices, :]
clf.fit(bootstrap[:, features], y[indices], current_weights[indices])
return (clf, features)
class Odte(BaseEnsemble, ClassifierMixin): class Odte(BaseEnsemble, ClassifierMixin):
def __init__( def __init__(
self, self,
@ -135,15 +106,12 @@ class Odte(BaseEnsemble, ClassifierMixin):
def _train( def _train(
self, X: np.ndarray, y: np.ndarray, weights: np.ndarray self, X: np.ndarray, y: np.ndarray, weights: np.ndarray
) -> Tuple[List[BaseEstimator], List[Tuple[int, ...]]]: ) -> Tuple[List[BaseEstimator], List[Tuple[int, ...]]]:
# np.random.RandomState(seed)
n_samples = X.shape[0] n_samples = X.shape[0]
boot_samples = self._get_bootstrap_n_samples(n_samples) boot_samples = self._get_bootstrap_n_samples(n_samples)
estimator = [] estimator = clone(self.base_estimator_)
for i in range(self.n_estimators):
estimator.append(clone(self.base_estimator_))
return Parallel(n_jobs=self.n_jobs, prefer="threads")( # type: ignore return Parallel(n_jobs=self.n_jobs, prefer="threads")( # type: ignore
delayed(_parallel_build_tree)( delayed(Odte._parallel_build_tree)(
estimator[i], estimator,
X, X,
y, y,
weights, weights,
@ -152,11 +120,40 @@ class Odte(BaseEnsemble, ClassifierMixin):
self.max_features_, self.max_features_,
self.be_hyperparams, self.be_hyperparams,
) )
for i, random_seed in enumerate( for random_seed in range(
range(self.random_state, self.random_state + self.n_estimators) self.random_state, self.random_state + self.n_estimators
) )
) )
@staticmethod
def _parallel_build_tree(
base_estimator_: BaseEstimator,
X: np.ndarray,
y: np.ndarray,
weights: np.ndarray,
random_seed: int,
boot_samples: int,
max_features: int,
hyperparams: str,
) -> Tuple[BaseEstimator, Tuple[int, ...]]:
clf = clone(base_estimator_)
hyperparams_ = json.loads(hyperparams)
hyperparams_.update(dict(random_state=random_seed))
clf.set_params(**hyperparams_)
n_samples = X.shape[0]
# bootstrap
random_box = check_random_state(random_seed)
indices = random_box.randint(0, n_samples, boot_samples)
# update weights with the chosen samples
weights_update = np.bincount(indices, minlength=n_samples)
current_weights = weights * weights_update
# random subspace
features = Odte._get_random_subspace(X, y, max_features)
# train the classifier
bootstrap = X[indices, :]
clf.fit(bootstrap[:, features], y[indices], current_weights[indices])
return (clf, features)
def _get_bootstrap_n_samples(self, n_samples: int) -> int: def _get_bootstrap_n_samples(self, n_samples: int) -> int:
if self.max_samples is None: if self.max_samples is None:
return n_samples return n_samples