From b82422912138030f114d3ba07f48f3741c44b41d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Sun, 7 Jun 2020 16:12:25 +0200 Subject: [PATCH] #1 Add min_samples_split Fix #1 --- stree/Strees.py | 8 +++++++- stree/Strees_grapher.py | 1 - stree/tests/Strees_test.py | 11 +++++++++++ 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/stree/Strees.py b/stree/Strees.py index 0e8b4cb..0992683 100644 --- a/stree/Strees.py +++ b/stree/Strees.py @@ -131,6 +131,7 @@ class Stree(BaseEstimator, ClassifierMixin): max_depth: int = None, tol: float = 1e-4, use_predictions: bool = False, + min_samples_split: int = 0, ): self.max_iter = max_iter self.C = C @@ -138,6 +139,7 @@ class Stree(BaseEstimator, ClassifierMixin): self.use_predictions = use_predictions self.max_depth = max_depth self.tol = tol + self.min_samples_split = min_samples_split def _more_tags(self) -> dict: """Required by sklearn to tell that this estimator is a binary classifier @@ -206,7 +208,11 @@ class Stree(BaseEstimator, ClassifierMixin): :return: [description] :rtype: np.array """ - return data > 0 + return ( + data > 0 + if data.shape[0] >= self.min_samples_split + else np.ones((data.shape[0], 1), dtype=bool) + ) def fit( self, X: np.ndarray, y: np.ndarray, sample_weight: np.array = None diff --git a/stree/Strees_grapher.py b/stree/Strees_grapher.py index 12b2b47..ba4fc4e 100644 --- a/stree/Strees_grapher.py +++ b/stree/Strees_grapher.py @@ -133,7 +133,6 @@ class Stree_grapher(Stree): os.environ.pop("TESTING") except KeyError: pass - plt.close("all") def _copy_tree(self, node: Snode) -> Snode_graph: mirror = Snode_graph(node) diff --git a/stree/tests/Strees_test.py b/stree/tests/Strees_test.py index 73dff81..52bb3ab 100644 --- a/stree/tests/Strees_test.py +++ b/stree/tests/Strees_test.py @@ -315,6 +315,17 @@ class Stree_test(unittest.TestCase): tcl = Stree() self.assertEqual(0, len(list(tcl))) + def test_min_samples_split(self): + tcl_split = Stree(min_samples_split=3) + tcl_nosplit = Stree(min_samples_split=4) + dataset = [[1], [2], [3]], [1, 1, 0] + tcl_split.fit(*dataset) + self.assertIsNotNone(tcl_split.tree_.get_down()) + self.assertIsNotNone(tcl_split.tree_.get_up()) + tcl_nosplit.fit(*dataset) + self.assertIsNone(tcl_nosplit.tree_.get_down()) + self.assertIsNone(tcl_nosplit.tree_.get_up()) + class Snode_test(unittest.TestCase): def __init__(self, *args, **kwargs):