#1 Add min_samples_split

Fix #1
This commit is contained in:
2020-06-07 16:12:25 +02:00
parent 8ba9b1b6a1
commit b824229121
3 changed files with 18 additions and 2 deletions

View File

@@ -131,6 +131,7 @@ class Stree(BaseEstimator, ClassifierMixin):
max_depth: int = None,
tol: float = 1e-4,
use_predictions: bool = False,
min_samples_split: int = 0,
):
self.max_iter = max_iter
self.C = C
@@ -138,6 +139,7 @@ class Stree(BaseEstimator, ClassifierMixin):
self.use_predictions = use_predictions
self.max_depth = max_depth
self.tol = tol
self.min_samples_split = min_samples_split
def _more_tags(self) -> dict:
"""Required by sklearn to tell that this estimator is a binary classifier
@@ -206,7 +208,11 @@ class Stree(BaseEstimator, ClassifierMixin):
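:return: boolean mask: ``data > 0`` when at least ``min_samples_split`` samples are given, otherwise an all-True column with one row per sample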
:rtype: np.array
"""
return data > 0
return (
data > 0
if data.shape[0] >= self.min_samples_split
else np.ones((data.shape[0], 1), dtype=bool)
)
def fit(
self, X: np.ndarray, y: np.ndarray, sample_weight: np.array = None

View File

@@ -133,7 +133,6 @@ class Stree_grapher(Stree):
os.environ.pop("TESTING")
except KeyError:
pass
plt.close("all")
def _copy_tree(self, node: Snode) -> Snode_graph:
mirror = Snode_graph(node)

View File

@@ -315,6 +315,17 @@ class Stree_test(unittest.TestCase):
tcl = Stree()
self.assertEqual(0, len(list(tcl)))
def test_min_samples_split(self):
    """Check that min_samples_split decides whether the fitted tree splits.

    The same three-sample dataset is fitted twice: once with a threshold
    the sample count reaches (3) and once with one it does not (4).
    """
    features, labels = [[1], [2], [3]], [1, 1, 0]
    splitting = Stree(min_samples_split=3)
    not_splitting = Stree(min_samples_split=4)

    # 3 samples >= 3: the fitted tree_ node must have both children.
    splitting.fit(features, labels)
    split_node = splitting.tree_
    self.assertIsNotNone(split_node.get_down())
    self.assertIsNotNone(split_node.get_up())

    # 3 samples < 4: the fitted tree_ node must have no children.
    not_splitting.fit(features, labels)
    unsplit_node = not_splitting.tree_
    self.assertIsNone(unsplit_node.get_down())
    self.assertIsNone(unsplit_node.get_up())
class Snode_test(unittest.TestCase):
def __init__(self, *args, **kwargs):