mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-19 01:16:00 +00:00
@@ -131,6 +131,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
max_depth: int = None,
|
||||
tol: float = 1e-4,
|
||||
use_predictions: bool = False,
|
||||
min_samples_split: int = 0,
|
||||
):
|
||||
self.max_iter = max_iter
|
||||
self.C = C
|
||||
@@ -138,6 +139,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
self.use_predictions = use_predictions
|
||||
self.max_depth = max_depth
|
||||
self.tol = tol
|
||||
self.min_samples_split = min_samples_split
|
||||
|
||||
def _more_tags(self) -> dict:
|
||||
"""Required by sklearn to tell that this estimator is a binary classifier
|
||||
@@ -206,7 +208,11 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
:return: [description]
|
||||
:rtype: np.array
|
||||
"""
|
||||
return data > 0
|
||||
return (
|
||||
data > 0
|
||||
if data.shape[0] >= self.min_samples_split
|
||||
else np.ones((data.shape[0], 1), dtype=bool)
|
||||
)
|
||||
|
||||
def fit(
|
||||
self, X: np.ndarray, y: np.ndarray, sample_weight: np.array = None
|
||||
|
Reference in New Issue
Block a user