Add separate methods to return nodes/leaves/depth

This commit is contained in:
Ricardo Montañana Gómez 2023-11-27 10:33:47 +01:00
parent f9b83adfee
commit 52d1095161
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
3 changed files with 28 additions and 1 deletions

View File

@ -251,6 +251,18 @@ class Odte(BaseEnsemble, ClassifierMixin):
result[i, predictions[i]] += 1
return result / self.n_estimators
def get_nodes(self) ->int:
check_is_fitted(self, "estimators_")
return self.nodes_
def get_leaves(self) ->int:
check_is_fitted(self, "estimators_")
return self.leaves_
def get_depth(self) ->int:
check_is_fitted(self, "estimators_")
return self.depth_
def nodes_leaves(self) -> Tuple[float, float]:
check_is_fitted(self, "estimators_")
return self.nodes_, self.leaves_

View File

@ -1 +1 @@
__version__ = "0.3.4"
__version__ = "0.3.5"

View File

@ -190,6 +190,12 @@ class Odte_test(unittest.TestCase):
)
with self.assertRaises(NotFittedError):
tclf.nodes_leaves()
with self.assertRaises(NotFittedError):
tclf.get_nodes()
with self.assertRaises(NotFittedError):
tclf.get_leaves()
with self.assertRaises(NotFittedError):
tclf.get_depth()
def test_nodes_leaves_depth(self):
tclf = Odte(
@ -209,11 +215,16 @@ class Odte_test(unittest.TestCase):
tclf_p.fit(X, y)
for clf in [tclf, tclf_p]:
self.assertAlmostEqual(5.8, clf.depth_)
self.assertAlmostEqual(5.8, clf.get_depth())
self.assertAlmostEqual(9.4, clf.leaves_)
self.assertAlmostEqual(9.4, clf.get_leaves())
self.assertAlmostEqual(17.8, clf.nodes_)
self.assertAlmostEqual(17.8, clf.get_nodes())
nodes, leaves = clf.nodes_leaves()
self.assertAlmostEqual(9.4, leaves)
self.assertAlmostEqual(9.4, clf.get_leaves())
self.assertAlmostEqual(17.8, nodes)
self.assertAlmostEqual(17.8, clf.get_nodes())
def test_nodes_leaves_SVC(self):
tclf = Odte(
@ -224,10 +235,14 @@ class Odte_test(unittest.TestCase):
X, y = load_dataset(self._random_state, n_features=16, n_samples=500)
tclf.fit(X, y)
self.assertAlmostEqual(0.0, tclf.leaves_)
self.assertAlmostEqual(0.0, tclf.get_leaves())
self.assertAlmostEqual(0.0, tclf.nodes_)
self.assertAlmostEqual(0.0, tclf.get_nodes())
nodes, leaves = tclf.nodes_leaves()
self.assertAlmostEqual(0.0, leaves)
self.assertAlmostEqual(0.0, tclf.get_leaves())
self.assertAlmostEqual(0.0, nodes)
self.assertAlmostEqual(0.0, tclf.get_nodes())
def test_estimator_hyperparams(self):
data = [