Fix compute number of nodes

This commit is contained in:
2021-04-13 22:31:05 +02:00
parent 783d105099
commit b55f59a3ec
2 changed files with 4 additions and 5 deletions

View File

@@ -841,10 +841,9 @@ class Stree(BaseEstimator, ClassifierMixin):
nodes = 0
leaves = 0
for node in self:
nodes += 1
if node.is_leaf():
leaves += 1
else:
nodes += 1
return nodes, leaves
def __iter__(self) -> Siterator:

View File

@@ -465,13 +465,13 @@ class Stree_test(unittest.TestCase):
clf = Stree(random_state=self._random_state)
clf.fit(X, y)
nodes, leaves = clf.nodes_leaves()
self.assertEqual(12, nodes)
self.assertEqual(25, nodes)
self.assertEquals(13, leaves)
X, y = load_wine(return_X_y=True)
clf = Stree(random_state=self._random_state)
clf.fit(X, y)
nodes, leaves = clf.nodes_leaves()
self.assertEqual(4, nodes)
self.assertEqual(9, nodes)
self.assertEquals(5, leaves)
def test_nodes_leaves_artificial(self):
@@ -489,5 +489,5 @@ class Stree_test(unittest.TestCase):
clf = Stree(random_state=self._random_state)
clf.tree_ = n1
nodes, leaves = clf.nodes_leaves()
self.assertEqual(4, nodes)
self.assertEqual(6, nodes)
self.assertEqual(2, leaves)