diff --git a/odte/Odte.py b/odte/Odte.py index c7f03bf..f071707 100644 --- a/odte/Odte.py +++ b/odte/Odte.py @@ -83,8 +83,21 @@ class Odte(BaseEnsemble, ClassifierMixin): # type: ignore self.subspaces_: List[Tuple[int, ...]] = [] result = self._train(X, y, sample_weight) self.estimators_, self.subspaces_ = tuple(zip(*result)) # type: ignore + self._compute_metrics() return self + def _compute_metrics(self) -> None: + tdepth = tnodes = tleaves = 0 + for estimator in self.estimators_: + nodes, leaves = estimator.nodes_leaves() + depth = estimator.depth_ + tdepth += depth + tnodes += nodes + tleaves += leaves + self.depth_ = tdepth / self.n_estimators + self.leaves_ = tleaves / self.n_estimators + self.nodes_ = tnodes / self.n_estimators + @staticmethod def _parallel_build_tree( base_estimator_: Stree, @@ -228,3 +241,7 @@ class Odte(BaseEnsemble, ClassifierMixin): # type: ignore for i in range(n_samples): result[i, predictions[i]] += 1 return result / self.n_estimators + + def nodes_leaves(self) -> list(float, float): + check_is_fitted(self, "estimators_") + return self.nodes_, self.leaves_ diff --git a/odte/tests/Odte_tests.py b/odte/tests/Odte_tests.py index 267172e..d52dd95 100644 --- a/odte/tests/Odte_tests.py +++ b/odte/tests/Odte_tests.py @@ -3,7 +3,7 @@ import unittest import os import random import warnings -from sklearn.exceptions import ConvergenceWarning +from sklearn.exceptions import ConvergenceWarning, NotFittedError from odte import Odte from stree import Stree @@ -191,3 +191,27 @@ class Odte_test(unittest.TestCase): from sklearn.utils.estimator_checks import check_estimator check_estimator(Odte()) + + def test_nodes_leaves_not_fitted(self): + tclf = Odte( + base_estimator=Stree(), + random_state=self._random_state, + n_estimators=3, + ) + with self.assertRaises(NotFittedError): + tclf.nodes_leaves() + + def test_nodes_leaves_depth(self): + tclf = Odte( + base_estimator=Stree(), + random_state=self._random_state, + n_estimators=3, + ) + X, y = load_dataset(self._random_state, n_features=16, n_samples=500) + tclf.fit(X, y) + self.assertAlmostEqual(6.0, tclf.depth_) + self.assertAlmostEqual(9.333333333333334, tclf.leaves_) + self.assertAlmostEqual(17.666666666666668, tclf.nodes_) + nodes, leaves = tclf.nodes_leaves() + self.assertAlmostEqual(9.333333333333334, leaves) + self.assertAlmostEqual(17.666666666666668, nodes)