mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-15 15:36:00 +00:00
Add a test Fix #27
This commit is contained in:
committed by
GitHub
parent
460c63a6d0
commit
6ba973dfe1
@@ -818,6 +818,23 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
score = y_true == y_pred
|
||||
return _weighted_sum(score, sample_weight, normalize=True)
|
||||
|
||||
def nodes_leaves(self) -> tuple:
|
||||
"""Compute the number of nodes and leaves in the built tree
|
||||
|
||||
Returns
|
||||
-------
|
||||
[tuple]
|
||||
tuple with the number of nodes and the number of leaves
|
||||
"""
|
||||
nodes = 0
|
||||
leaves = 0
|
||||
for node in self:
|
||||
if node.is_leaf():
|
||||
leaves += 1
|
||||
else:
|
||||
nodes += 1
|
||||
return nodes, leaves
|
||||
|
||||
def __iter__(self) -> Siterator:
|
||||
"""Create an iterator to be able to visit the nodes of the tree in
|
||||
preorder, can make a list with all the nodes in preorder
|
||||
|
@@ -454,3 +454,22 @@ class Stree_test(unittest.TestCase):
|
||||
clf = Stree(random_state=self._random_state)
|
||||
clf.fit(X, y)
|
||||
self.assertEqual(7, clf.depth_)
|
||||
|
||||
def test_nodes_leaves(self):
|
||||
X, y = load_dataset(
|
||||
random_state=self._random_state,
|
||||
n_classes=3,
|
||||
n_features=5,
|
||||
n_samples=1500,
|
||||
)
|
||||
clf = Stree(random_state=self._random_state)
|
||||
clf.fit(X, y)
|
||||
nodes, leaves = clf.nodes_leaves()
|
||||
self.assertEqual(12, nodes)
|
||||
self.assertEquals(13, leaves)
|
||||
X, y = load_wine(return_X_y=True)
|
||||
clf = Stree(random_state=self._random_state)
|
||||
clf.fit(X, y)
|
||||
nodes, leaves = clf.nodes_leaves()
|
||||
self.assertEqual(8, nodes)
|
||||
self.assertEquals(9, leaves)
|
||||
|
Reference in New Issue
Block a user