mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-15 15:36:00 +00:00
@@ -131,6 +131,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
max_depth: int = None,
|
||||
tol: float = 1e-4,
|
||||
use_predictions: bool = False,
|
||||
min_samples_split: int = 0,
|
||||
):
|
||||
self.max_iter = max_iter
|
||||
self.C = C
|
||||
@@ -138,6 +139,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
self.use_predictions = use_predictions
|
||||
self.max_depth = max_depth
|
||||
self.tol = tol
|
||||
self.min_samples_split = min_samples_split
|
||||
|
||||
def _more_tags(self) -> dict:
|
||||
"""Required by sklearn to tell that this estimator is a binary classifier
|
||||
@@ -206,7 +208,11 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
:return: [description]
|
||||
:rtype: np.array
|
||||
"""
|
||||
return data > 0
|
||||
return (
|
||||
data > 0
|
||||
if data.shape[0] >= self.min_samples_split
|
||||
else np.ones((data.shape[0], 1), dtype=bool)
|
||||
)
|
||||
|
||||
def fit(
|
||||
self, X: np.ndarray, y: np.ndarray, sample_weight: np.array = None
|
||||
|
@@ -133,7 +133,6 @@ class Stree_grapher(Stree):
|
||||
os.environ.pop("TESTING")
|
||||
except KeyError:
|
||||
pass
|
||||
plt.close("all")
|
||||
|
||||
def _copy_tree(self, node: Snode) -> Snode_graph:
|
||||
mirror = Snode_graph(node)
|
||||
|
@@ -315,6 +315,17 @@ class Stree_test(unittest.TestCase):
|
||||
tcl = Stree()
|
||||
self.assertEqual(0, len(list(tcl)))
|
||||
|
||||
def test_min_samples_split(self):
|
||||
tcl_split = Stree(min_samples_split=3)
|
||||
tcl_nosplit = Stree(min_samples_split=4)
|
||||
dataset = [[1], [2], [3]], [1, 1, 0]
|
||||
tcl_split.fit(*dataset)
|
||||
self.assertIsNotNone(tcl_split.tree_.get_down())
|
||||
self.assertIsNotNone(tcl_split.tree_.get_up())
|
||||
tcl_nosplit.fit(*dataset)
|
||||
self.assertIsNone(tcl_nosplit.tree_.get_down())
|
||||
self.assertIsNone(tcl_nosplit.tree_.get_up())
|
||||
|
||||
|
||||
class Snode_test(unittest.TestCase):
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
Reference in New Issue
Block a user