Add nodes, leaves and depth average

This commit is contained in:
Ricardo Montañana Gómez 2021-11-10 13:35:13 +01:00
parent cfda03682b
commit 3a06c9d1cc
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
2 changed files with 42 additions and 1 deletions

View File

@ -83,8 +83,21 @@ class Odte(BaseEnsemble, ClassifierMixin): # type: ignore
self.subspaces_: List[Tuple[int, ...]] = [] self.subspaces_: List[Tuple[int, ...]] = []
result = self._train(X, y, sample_weight) result = self._train(X, y, sample_weight)
self.estimators_, self.subspaces_ = tuple(zip(*result)) # type: ignore self.estimators_, self.subspaces_ = tuple(zip(*result)) # type: ignore
self._compute_metrics()
return self return self
def _compute_metrics(self) -> None:
    """Store per-estimator averages of depth, leaves and nodes.

    Walks ``self.estimators_`` once, reading each estimator's
    ``nodes_leaves()`` pair and its ``depth_`` attribute, then divides
    every total by ``self.n_estimators`` and stores the results in
    ``self.depth_``, ``self.leaves_`` and ``self.nodes_``.
    """
    per_tree = []
    for tree in self.estimators_:
        tree_nodes, tree_leaves = tree.nodes_leaves()
        per_tree.append((tree.depth_, tree_nodes, tree_leaves))
    # The divisor is the configured ensemble size, not len(estimators_).
    count = self.n_estimators
    self.depth_ = sum(stats[0] for stats in per_tree) / count
    self.leaves_ = sum(stats[2] for stats in per_tree) / count
    self.nodes_ = sum(stats[1] for stats in per_tree) / count
@staticmethod @staticmethod
def _parallel_build_tree( def _parallel_build_tree(
base_estimator_: Stree, base_estimator_: Stree,
@ -228,3 +241,7 @@ class Odte(BaseEnsemble, ClassifierMixin): # type: ignore
for i in range(n_samples): for i in range(n_samples):
result[i, predictions[i]] += 1 result[i, predictions[i]] += 1
return result / self.n_estimators return result / self.n_estimators
def nodes_leaves(self) -> Tuple[float, float]:
    """Return the ensemble's stored node and leaf figures.

    Returns
    -------
    Tuple[float, float]
        ``(self.nodes_, self.leaves_)``, in that order.

    Raises
    ------
    NotFittedError
        Raised by ``check_is_fitted`` when ``estimators_`` is not set,
        i.e. the ensemble has not been fitted yet.
    """
    # The annotation used to be ``list(float, float)``, which is a call
    # evaluated at definition time and raises TypeError; the method also
    # returns a tuple, not a list.
    check_is_fitted(self, "estimators_")
    return self.nodes_, self.leaves_

View File

@ -3,7 +3,7 @@ import unittest
import os import os
import random import random
import warnings import warnings
from sklearn.exceptions import ConvergenceWarning from sklearn.exceptions import ConvergenceWarning, NotFittedError
from odte import Odte from odte import Odte
from stree import Stree from stree import Stree
@ -191,3 +191,27 @@ class Odte_test(unittest.TestCase):
from sklearn.utils.estimator_checks import check_estimator from sklearn.utils.estimator_checks import check_estimator
check_estimator(Odte()) check_estimator(Odte())
def test_nodes_leaves_not_fitted(self):
    """Calling nodes_leaves() before fit() must raise NotFittedError."""
    unfitted = Odte(
        base_estimator=Stree(),
        random_state=self._random_state,
        n_estimators=3,
    )
    # Callable form of assertRaises: same check as the context manager.
    self.assertRaises(NotFittedError, unfitted.nodes_leaves)
def test_nodes_leaves_depth(self):
    """Check the averaged depth, leaves and nodes of a fitted ensemble.

    Fits a three-estimator Odte on the dataset returned by
    ``load_dataset`` (requested with 16 features and 500 samples) and
    compares the stored metrics, and the pair returned by
    ``nodes_leaves()``, with the reference values below.
    """
    expected_depth = 6.0
    expected_leaves = 9.333333333333334
    expected_nodes = 17.666666666666668
    ensemble = Odte(
        base_estimator=Stree(),
        random_state=self._random_state,
        n_estimators=3,
    )
    features, labels = load_dataset(
        self._random_state, n_features=16, n_samples=500
    )
    ensemble.fit(features, labels)
    # Attributes set during fit.
    self.assertAlmostEqual(expected_depth, ensemble.depth_)
    self.assertAlmostEqual(expected_leaves, ensemble.leaves_)
    self.assertAlmostEqual(expected_nodes, ensemble.nodes_)
    # The accessor must report the same figures, nodes first.
    reported_nodes, reported_leaves = ensemble.nodes_leaves()
    self.assertAlmostEqual(expected_leaves, reported_leaves)
    self.assertAlmostEqual(expected_nodes, reported_nodes)