Add nodes, leaves and depth average

This commit is contained in:
Ricardo Montañana Gómez 2021-11-10 13:35:13 +01:00
parent cfda03682b
commit 3a06c9d1cc
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
2 changed files with 42 additions and 1 deletions

View File

@ -83,8 +83,21 @@ class Odte(BaseEnsemble, ClassifierMixin): # type: ignore
self.subspaces_: List[Tuple[int, ...]] = []
result = self._train(X, y, sample_weight)
self.estimators_, self.subspaces_ = tuple(zip(*result)) # type: ignore
self._compute_metrics()
return self
def _compute_metrics(self) -> None:
tdepth = tnodes = tleaves = 0
for estimator in self.estimators_:
nodes, leaves = estimator.nodes_leaves()
depth = estimator.depth_
tdepth += depth
tnodes += nodes
tleaves += leaves
self.depth_ = tdepth / self.n_estimators
self.leaves_ = tleaves / self.n_estimators
self.nodes_ = tnodes / self.n_estimators
@staticmethod
def _parallel_build_tree(
base_estimator_: Stree,
@ -228,3 +241,7 @@ class Odte(BaseEnsemble, ClassifierMixin): # type: ignore
for i in range(n_samples):
result[i, predictions[i]] += 1
return result / self.n_estimators
def nodes_leaves(self) -> list(float, float):
check_is_fitted(self, "estimators_")
return self.nodes_, self.leaves_

View File

@ -3,7 +3,7 @@ import unittest
import os
import random
import warnings
from sklearn.exceptions import ConvergenceWarning
from sklearn.exceptions import ConvergenceWarning, NotFittedError
from odte import Odte
from stree import Stree
@ -191,3 +191,27 @@ class Odte_test(unittest.TestCase):
from sklearn.utils.estimator_checks import check_estimator
check_estimator(Odte())
def test_nodes_leaves_not_fitted(self):
tclf = Odte(
base_estimator=Stree(),
random_state=self._random_state,
n_estimators=3,
)
with self.assertRaises(NotFittedError):
tclf.nodes_leaves()
def test_nodes_leaves_depth(self):
tclf = Odte(
base_estimator=Stree(),
random_state=self._random_state,
n_estimators=3,
)
X, y = load_dataset(self._random_state, n_features=16, n_samples=500)
tclf.fit(X, y)
self.assertAlmostEqual(6.0, tclf.depth_)
self.assertAlmostEqual(9.333333333333334, tclf.leaves_)
self.assertAlmostEqual(17.666666666666668, tclf.nodes_)
nodes, leaves = tclf.nodes_leaves()
self.assertAlmostEqual(9.333333333333334, leaves)
self.assertAlmostEqual(17.666666666666668, nodes)