mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-18 00:46:02 +00:00
Add separate methods to return nodes/leaves/depth
This commit is contained in:
@@ -484,6 +484,43 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
X = self.check_predict(X)
|
X = self.check_predict(X)
|
||||||
return self.classes_[np.argmax(self.__predict_class(X), axis=1)]
|
return self.classes_[np.argmax(self.__predict_class(X), axis=1)]
|
||||||
|
|
||||||
|
def get_nodes(self) -> int:
|
||||||
|
"""Return the number of nodes in the tree
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
int
|
||||||
|
number of nodes
|
||||||
|
"""
|
||||||
|
nodes = 0
|
||||||
|
for _ in self:
|
||||||
|
nodes += 1
|
||||||
|
return nodes
|
||||||
|
|
||||||
|
def get_leaves(self) -> int:
|
||||||
|
"""Return the number of leaves in the tree
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
int
|
||||||
|
number of leaves
|
||||||
|
"""
|
||||||
|
leaves = 0
|
||||||
|
for node in self:
|
||||||
|
if node.is_leaf():
|
||||||
|
leaves += 1
|
||||||
|
return leaves
|
||||||
|
|
||||||
|
def get_depth(self) -> int:
|
||||||
|
"""Return the depth of the tree
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
int
|
||||||
|
depth of the tree
|
||||||
|
"""
|
||||||
|
return self.depth_
|
||||||
|
|
||||||
def nodes_leaves(self) -> tuple:
|
def nodes_leaves(self) -> tuple:
|
||||||
"""Compute the number of nodes and leaves in the built tree
|
"""Compute the number of nodes and leaves in the built tree
|
||||||
|
|
||||||
|
@@ -1 +1 @@
|
|||||||
__version__ = "1.3.1"
|
__version__ = "1.3.2"
|
||||||
|
@@ -239,6 +239,7 @@ class Stree_test(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
tcl.fit(*load_dataset(self._random_state))
|
tcl.fit(*load_dataset(self._random_state))
|
||||||
self.assertEqual(depth, tcl.depth_)
|
self.assertEqual(depth, tcl.depth_)
|
||||||
|
self.assertEqual(depth, tcl.get_depth())
|
||||||
|
|
||||||
def test_unfitted_tree_is_iterable(self):
|
def test_unfitted_tree_is_iterable(self):
|
||||||
tcl = Stree()
|
tcl = Stree()
|
||||||
@@ -640,10 +641,12 @@ class Stree_test(unittest.TestCase):
|
|||||||
clf = Stree(random_state=self._random_state)
|
clf = Stree(random_state=self._random_state)
|
||||||
clf.fit(X, y)
|
clf.fit(X, y)
|
||||||
self.assertEqual(6, clf.depth_)
|
self.assertEqual(6, clf.depth_)
|
||||||
|
self.assertEqual(6, clf.get_depth())
|
||||||
X, y = load_wine(return_X_y=True)
|
X, y = load_wine(return_X_y=True)
|
||||||
clf = Stree(random_state=self._random_state)
|
clf = Stree(random_state=self._random_state)
|
||||||
clf.fit(X, y)
|
clf.fit(X, y)
|
||||||
self.assertEqual(4, clf.depth_)
|
self.assertEqual(4, clf.depth_)
|
||||||
|
self.assertEqual(4, clf.get_depth())
|
||||||
|
|
||||||
def test_nodes_leaves(self):
|
def test_nodes_leaves(self):
|
||||||
"""Check number of nodes and leaves."""
|
"""Check number of nodes and leaves."""
|
||||||
@@ -657,13 +660,17 @@ class Stree_test(unittest.TestCase):
|
|||||||
clf.fit(X, y)
|
clf.fit(X, y)
|
||||||
nodes, leaves = clf.nodes_leaves()
|
nodes, leaves = clf.nodes_leaves()
|
||||||
self.assertEqual(31, nodes)
|
self.assertEqual(31, nodes)
|
||||||
|
self.assertEqual(31, clf.get_nodes())
|
||||||
self.assertEqual(16, leaves)
|
self.assertEqual(16, leaves)
|
||||||
|
self.assertEqual(16, clf.get_leaves())
|
||||||
X, y = load_wine(return_X_y=True)
|
X, y = load_wine(return_X_y=True)
|
||||||
clf = Stree(random_state=self._random_state)
|
clf = Stree(random_state=self._random_state)
|
||||||
clf.fit(X, y)
|
clf.fit(X, y)
|
||||||
nodes, leaves = clf.nodes_leaves()
|
nodes, leaves = clf.nodes_leaves()
|
||||||
self.assertEqual(11, nodes)
|
self.assertEqual(11, nodes)
|
||||||
|
self.assertEqual(11, clf.get_nodes())
|
||||||
self.assertEqual(6, leaves)
|
self.assertEqual(6, leaves)
|
||||||
|
self.assertEqual(6, clf.get_leaves())
|
||||||
|
|
||||||
def test_nodes_leaves_artificial(self):
|
def test_nodes_leaves_artificial(self):
|
||||||
"""Check leaves of artificial dataset."""
|
"""Check leaves of artificial dataset."""
|
||||||
@@ -682,7 +689,9 @@ class Stree_test(unittest.TestCase):
|
|||||||
clf.tree_ = n1
|
clf.tree_ = n1
|
||||||
nodes, leaves = clf.nodes_leaves()
|
nodes, leaves = clf.nodes_leaves()
|
||||||
self.assertEqual(6, nodes)
|
self.assertEqual(6, nodes)
|
||||||
|
self.assertEqual(6, clf.get_nodes())
|
||||||
self.assertEqual(2, leaves)
|
self.assertEqual(2, leaves)
|
||||||
|
self.assertEqual(2, clf.get_leaves())
|
||||||
|
|
||||||
def test_bogus_multiclass_strategy(self):
|
def test_bogus_multiclass_strategy(self):
|
||||||
"""Check invalid multiclass strategy."""
|
"""Check invalid multiclass strategy."""
|
||||||
|
Reference in New Issue
Block a user