diff --git a/stree/Strees.py b/stree/Strees.py index 7c29a2e..c36f190 100644 --- a/stree/Strees.py +++ b/stree/Strees.py @@ -818,6 +818,23 @@ class Stree(BaseEstimator, ClassifierMixin): score = y_true == y_pred return _weighted_sum(score, sample_weight, normalize=True) + def nodes_leaves(self) -> tuple: + """Compute the number of nodes and leaves in the built tree + + Returns + ------- + [tuple] + tuple with the number of nodes and the number of leaves + """ + nodes = 0 + leaves = 0 + for node in self: + if node.is_leaf(): + leaves += 1 + else: + nodes += 1 + return nodes, leaves + def __iter__(self) -> Siterator: """Create an iterator to be able to visit the nodes of the tree in preorder, can make a list with all the nodes in preorder diff --git a/stree/tests/Stree_test.py b/stree/tests/Stree_test.py index 319f60f..5deb372 100644 --- a/stree/tests/Stree_test.py +++ b/stree/tests/Stree_test.py @@ -454,3 +454,22 @@ class Stree_test(unittest.TestCase): clf = Stree(random_state=self._random_state) clf.fit(X, y) self.assertEqual(7, clf.depth_) + + def test_nodes_leaves(self): + X, y = load_dataset( + random_state=self._random_state, + n_classes=3, + n_features=5, + n_samples=1500, + ) + clf = Stree(random_state=self._random_state) + clf.fit(X, y) + nodes, leaves = clf.nodes_leaves() + self.assertEqual(12, nodes) + self.assertEquals(13, leaves) + X, y = load_wine(return_X_y=True) + clf = Stree(random_state=self._random_state) + clf.fit(X, y) + nodes, leaves = clf.nodes_leaves() + self.assertEqual(8, nodes) + self.assertEquals(9, leaves)