mirror of
https://github.com/Doctorado-ML/Odte.git
synced 2025-07-11 08:12:06 +00:00
fix rc1
This commit is contained in:
parent
9e5fe8c791
commit
877c24f3f4
71
odte/Odte.py
71
odte/Odte.py
@ -26,35 +26,6 @@ from stree import Stree # type: ignore
|
||||
from ._version import __version__
|
||||
|
||||
|
||||
def _parallel_build_tree(
|
||||
base_estimator_: Stree,
|
||||
X: np.ndarray,
|
||||
y: np.ndarray,
|
||||
weights: np.ndarray,
|
||||
random_seed: int,
|
||||
boot_samples: int,
|
||||
max_features: int,
|
||||
hyperparams: str,
|
||||
) -> Tuple[BaseEstimator, Tuple[int, ...]]:
|
||||
clf = base_estimator_
|
||||
hyperparams_ = json.loads(hyperparams)
|
||||
hyperparams_.update(dict(random_state=random_seed))
|
||||
clf.set_params(**hyperparams_)
|
||||
n_samples = X.shape[0]
|
||||
# bootstrap
|
||||
random_box = check_random_state(random_seed)
|
||||
indices = random_box.randint(0, n_samples, boot_samples)
|
||||
# update weights with the chosen samples
|
||||
weights_update = np.bincount(indices, minlength=n_samples)
|
||||
current_weights = weights * weights_update
|
||||
# random subspace
|
||||
features = Odte._get_random_subspace(X, y, max_features)
|
||||
# train the classifier
|
||||
bootstrap = X[indices, :]
|
||||
clf.fit(bootstrap[:, features], y[indices], current_weights[indices])
|
||||
return (clf, features)
|
||||
|
||||
|
||||
class Odte(BaseEnsemble, ClassifierMixin):
|
||||
def __init__(
|
||||
self,
|
||||
@ -135,15 +106,12 @@ class Odte(BaseEnsemble, ClassifierMixin):
|
||||
def _train(
|
||||
self, X: np.ndarray, y: np.ndarray, weights: np.ndarray
|
||||
) -> Tuple[List[BaseEstimator], List[Tuple[int, ...]]]:
|
||||
# np.random.RandomState(seed)
|
||||
n_samples = X.shape[0]
|
||||
boot_samples = self._get_bootstrap_n_samples(n_samples)
|
||||
estimator = []
|
||||
for i in range(self.n_estimators):
|
||||
estimator.append(clone(self.base_estimator_))
|
||||
estimator = clone(self.base_estimator_)
|
||||
return Parallel(n_jobs=self.n_jobs, prefer="threads")( # type: ignore
|
||||
delayed(_parallel_build_tree)(
|
||||
estimator[i],
|
||||
delayed(Odte._parallel_build_tree)(
|
||||
estimator,
|
||||
X,
|
||||
y,
|
||||
weights,
|
||||
@ -152,11 +120,40 @@ class Odte(BaseEnsemble, ClassifierMixin):
|
||||
self.max_features_,
|
||||
self.be_hyperparams,
|
||||
)
|
||||
for i, random_seed in enumerate(
|
||||
range(self.random_state, self.random_state + self.n_estimators)
|
||||
for random_seed in range(
|
||||
self.random_state, self.random_state + self.n_estimators
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parallel_build_tree(
|
||||
base_estimator_: BaseEstimator,
|
||||
X: np.ndarray,
|
||||
y: np.ndarray,
|
||||
weights: np.ndarray,
|
||||
random_seed: int,
|
||||
boot_samples: int,
|
||||
max_features: int,
|
||||
hyperparams: str,
|
||||
) -> Tuple[BaseEstimator, Tuple[int, ...]]:
|
||||
clf = clone(base_estimator_)
|
||||
hyperparams_ = json.loads(hyperparams)
|
||||
hyperparams_.update(dict(random_state=random_seed))
|
||||
clf.set_params(**hyperparams_)
|
||||
n_samples = X.shape[0]
|
||||
# bootstrap
|
||||
random_box = check_random_state(random_seed)
|
||||
indices = random_box.randint(0, n_samples, boot_samples)
|
||||
# update weights with the chosen samples
|
||||
weights_update = np.bincount(indices, minlength=n_samples)
|
||||
current_weights = weights * weights_update
|
||||
# random subspace
|
||||
features = Odte._get_random_subspace(X, y, max_features)
|
||||
# train the classifier
|
||||
bootstrap = X[indices, :]
|
||||
clf.fit(bootstrap[:, features], y[indices], current_weights[indices])
|
||||
return (clf, features)
|
||||
|
||||
def _get_bootstrap_n_samples(self, n_samples: int) -> int:
|
||||
if self.max_samples is None:
|
||||
return n_samples
|
||||
|
Loading…
x
Reference in New Issue
Block a user