#1 Add min_samples_split

Fix #1
This commit is contained in:
2020-06-07 16:12:25 +02:00
parent 8ba9b1b6a1
commit b824229121
3 changed files with 18 additions and 2 deletions

View File

@@ -131,6 +131,7 @@ class Stree(BaseEstimator, ClassifierMixin):
max_depth: int = None,
tol: float = 1e-4,
use_predictions: bool = False,
min_samples_split: int = 0,
):
self.max_iter = max_iter
self.C = C
@@ -138,6 +139,7 @@ class Stree(BaseEstimator, ClassifierMixin):
self.use_predictions = use_predictions
self.max_depth = max_depth
self.tol = tol
self.min_samples_split = min_samples_split
def _more_tags(self) -> dict:
"""Required by sklearn to tell that this estimator is a binary classifier
@@ -206,7 +208,11 @@ class Stree(BaseEstimator, ClassifierMixin):
:return: [description]
:rtype: np.array
"""
return data > 0
return (
data > 0
if data.shape[0] >= self.min_samples_split
else np.ones((data.shape[0], 1), dtype=bool)
)
def fit(
self, X: np.ndarray, y: np.ndarray, sample_weight: np.array = None