mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-16 16:06:01 +00:00
@@ -131,6 +131,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
max_depth: int = None,
|
max_depth: int = None,
|
||||||
tol: float = 1e-4,
|
tol: float = 1e-4,
|
||||||
use_predictions: bool = False,
|
use_predictions: bool = False,
|
||||||
|
min_samples_split: int = 0,
|
||||||
):
|
):
|
||||||
self.max_iter = max_iter
|
self.max_iter = max_iter
|
||||||
self.C = C
|
self.C = C
|
||||||
@@ -138,6 +139,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
self.use_predictions = use_predictions
|
self.use_predictions = use_predictions
|
||||||
self.max_depth = max_depth
|
self.max_depth = max_depth
|
||||||
self.tol = tol
|
self.tol = tol
|
||||||
|
self.min_samples_split = min_samples_split
|
||||||
|
|
||||||
def _more_tags(self) -> dict:
|
def _more_tags(self) -> dict:
|
||||||
"""Required by sklearn to tell that this estimator is a binary classifier
|
"""Required by sklearn to tell that this estimator is a binary classifier
|
||||||
@@ -206,7 +208,11 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
:return: [description]
|
:return: [description]
|
||||||
:rtype: np.array
|
:rtype: np.array
|
||||||
"""
|
"""
|
||||||
return data > 0
|
return (
|
||||||
|
data > 0
|
||||||
|
if data.shape[0] >= self.min_samples_split
|
||||||
|
else np.ones((data.shape[0], 1), dtype=bool)
|
||||||
|
)
|
||||||
|
|
||||||
def fit(
|
def fit(
|
||||||
self, X: np.ndarray, y: np.ndarray, sample_weight: np.array = None
|
self, X: np.ndarray, y: np.ndarray, sample_weight: np.array = None
|
||||||
|
@@ -133,7 +133,6 @@ class Stree_grapher(Stree):
|
|||||||
os.environ.pop("TESTING")
|
os.environ.pop("TESTING")
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
plt.close("all")
|
|
||||||
|
|
||||||
def _copy_tree(self, node: Snode) -> Snode_graph:
|
def _copy_tree(self, node: Snode) -> Snode_graph:
|
||||||
mirror = Snode_graph(node)
|
mirror = Snode_graph(node)
|
||||||
|
@@ -315,6 +315,17 @@ class Stree_test(unittest.TestCase):
|
|||||||
tcl = Stree()
|
tcl = Stree()
|
||||||
self.assertEqual(0, len(list(tcl)))
|
self.assertEqual(0, len(list(tcl)))
|
||||||
|
|
||||||
|
def test_min_samples_split(self):
|
||||||
|
tcl_split = Stree(min_samples_split=3)
|
||||||
|
tcl_nosplit = Stree(min_samples_split=4)
|
||||||
|
dataset = [[1], [2], [3]], [1, 1, 0]
|
||||||
|
tcl_split.fit(*dataset)
|
||||||
|
self.assertIsNotNone(tcl_split.tree_.get_down())
|
||||||
|
self.assertIsNotNone(tcl_split.tree_.get_up())
|
||||||
|
tcl_nosplit.fit(*dataset)
|
||||||
|
self.assertIsNone(tcl_nosplit.tree_.get_down())
|
||||||
|
self.assertIsNone(tcl_nosplit.tree_.get_up())
|
||||||
|
|
||||||
|
|
||||||
class Snode_test(unittest.TestCase):
|
class Snode_test(unittest.TestCase):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
|
Reference in New Issue
Block a user