mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-15 23:46:02 +00:00
Fix compute number of nodes
This commit is contained in:
@@ -841,10 +841,9 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
nodes = 0
|
||||
leaves = 0
|
||||
for node in self:
|
||||
nodes += 1
|
||||
if node.is_leaf():
|
||||
leaves += 1
|
||||
else:
|
||||
nodes += 1
|
||||
return nodes, leaves
|
||||
|
||||
def __iter__(self) -> Siterator:
|
||||
|
@@ -465,13 +465,13 @@ class Stree_test(unittest.TestCase):
|
||||
clf = Stree(random_state=self._random_state)
|
||||
clf.fit(X, y)
|
||||
nodes, leaves = clf.nodes_leaves()
|
||||
self.assertEqual(12, nodes)
|
||||
self.assertEqual(25, nodes)
|
||||
self.assertEquals(13, leaves)
|
||||
X, y = load_wine(return_X_y=True)
|
||||
clf = Stree(random_state=self._random_state)
|
||||
clf.fit(X, y)
|
||||
nodes, leaves = clf.nodes_leaves()
|
||||
self.assertEqual(4, nodes)
|
||||
self.assertEqual(9, nodes)
|
||||
self.assertEquals(5, leaves)
|
||||
|
||||
def test_nodes_leaves_artificial(self):
|
||||
@@ -489,5 +489,5 @@ class Stree_test(unittest.TestCase):
|
||||
clf = Stree(random_state=self._random_state)
|
||||
clf.tree_ = n1
|
||||
nodes, leaves = clf.nodes_leaves()
|
||||
self.assertEqual(4, nodes)
|
||||
self.assertEqual(6, nodes)
|
||||
self.assertEqual(2, leaves)
|
||||
|
Reference in New Issue
Block a user