mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-15 23:46:02 +00:00
Fix compute number of nodes
This commit is contained in:
@@ -841,10 +841,9 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
nodes = 0
|
nodes = 0
|
||||||
leaves = 0
|
leaves = 0
|
||||||
for node in self:
|
for node in self:
|
||||||
|
nodes += 1
|
||||||
if node.is_leaf():
|
if node.is_leaf():
|
||||||
leaves += 1
|
leaves += 1
|
||||||
else:
|
|
||||||
nodes += 1
|
|
||||||
return nodes, leaves
|
return nodes, leaves
|
||||||
|
|
||||||
def __iter__(self) -> Siterator:
|
def __iter__(self) -> Siterator:
|
||||||
|
@@ -465,13 +465,13 @@ class Stree_test(unittest.TestCase):
|
|||||||
clf = Stree(random_state=self._random_state)
|
clf = Stree(random_state=self._random_state)
|
||||||
clf.fit(X, y)
|
clf.fit(X, y)
|
||||||
nodes, leaves = clf.nodes_leaves()
|
nodes, leaves = clf.nodes_leaves()
|
||||||
self.assertEqual(12, nodes)
|
self.assertEqual(25, nodes)
|
||||||
self.assertEquals(13, leaves)
|
self.assertEquals(13, leaves)
|
||||||
X, y = load_wine(return_X_y=True)
|
X, y = load_wine(return_X_y=True)
|
||||||
clf = Stree(random_state=self._random_state)
|
clf = Stree(random_state=self._random_state)
|
||||||
clf.fit(X, y)
|
clf.fit(X, y)
|
||||||
nodes, leaves = clf.nodes_leaves()
|
nodes, leaves = clf.nodes_leaves()
|
||||||
self.assertEqual(4, nodes)
|
self.assertEqual(9, nodes)
|
||||||
self.assertEquals(5, leaves)
|
self.assertEquals(5, leaves)
|
||||||
|
|
||||||
def test_nodes_leaves_artificial(self):
|
def test_nodes_leaves_artificial(self):
|
||||||
@@ -489,5 +489,5 @@ class Stree_test(unittest.TestCase):
|
|||||||
clf = Stree(random_state=self._random_state)
|
clf = Stree(random_state=self._random_state)
|
||||||
clf.tree_ = n1
|
clf.tree_ = n1
|
||||||
nodes, leaves = clf.nodes_leaves()
|
nodes, leaves = clf.nodes_leaves()
|
||||||
self.assertEqual(4, nodes)
|
self.assertEqual(6, nodes)
|
||||||
self.assertEqual(2, leaves)
|
self.assertEqual(2, leaves)
|
||||||
|
Reference in New Issue
Block a user