diff --git a/stree/Strees.py b/stree/Strees.py index dc95247..7c32d32 100644 --- a/stree/Strees.py +++ b/stree/Strees.py @@ -440,7 +440,12 @@ class Stree(BaseEstimator, ClassifierMixin): check_classification_targets(y) X, y = check_X_y(X, y) - sample_weight = _check_sample_weight(sample_weight, X) + sample_weight = _check_sample_weight( + sample_weight, X, dtype=np.float64 + ) + # solve WARNING: class label 0 specified in weight is not found + # in bagging + sample_weight += 1e-5 check_classification_targets(y) # Initialize computed parameters self.splitter_ = Splitter(