diff --git a/stree/Strees.py b/stree/Strees.py index 0cfaa66..d44da0f 100644 --- a/stree/Strees.py +++ b/stree/Strees.py @@ -841,10 +841,9 @@ class Stree(BaseEstimator, ClassifierMixin): nodes = 0 leaves = 0 for node in self: + nodes += 1 if node.is_leaf(): leaves += 1 - else: - nodes += 1 return nodes, leaves def __iter__(self) -> Siterator: diff --git a/stree/tests/Stree_test.py b/stree/tests/Stree_test.py index 4b01201..ce50ae5 100644 --- a/stree/tests/Stree_test.py +++ b/stree/tests/Stree_test.py @@ -465,13 +465,13 @@ class Stree_test(unittest.TestCase): clf = Stree(random_state=self._random_state) clf.fit(X, y) nodes, leaves = clf.nodes_leaves() - self.assertEqual(12, nodes) + self.assertEqual(25, nodes) self.assertEquals(13, leaves) X, y = load_wine(return_X_y=True) clf = Stree(random_state=self._random_state) clf.fit(X, y) nodes, leaves = clf.nodes_leaves() - self.assertEqual(4, nodes) + self.assertEqual(9, nodes) self.assertEquals(5, leaves) def test_nodes_leaves_artificial(self): @@ -489,5 +489,5 @@ class Stree_test(unittest.TestCase): clf = Stree(random_state=self._random_state) clf.tree_ = n1 nodes, leaves = clf.nodes_leaves() - self.assertEqual(4, nodes) + self.assertEqual(6, nodes) self.assertEqual(2, leaves)