mirror of
https://github.com/Doctorado-ML/Odte.git
synced 2025-07-11 00:02:30 +00:00
Add nodes, leaves and depth average
This commit is contained in:
parent
cfda03682b
commit
3a06c9d1cc
17
odte/Odte.py
17
odte/Odte.py
@ -83,8 +83,21 @@ class Odte(BaseEnsemble, ClassifierMixin): # type: ignore
|
||||
self.subspaces_: List[Tuple[int, ...]] = []
|
||||
result = self._train(X, y, sample_weight)
|
||||
self.estimators_, self.subspaces_ = tuple(zip(*result)) # type: ignore
|
||||
self._compute_metrics()
|
||||
return self
|
||||
|
||||
def _compute_metrics(self) -> None:
|
||||
tdepth = tnodes = tleaves = 0
|
||||
for estimator in self.estimators_:
|
||||
nodes, leaves = estimator.nodes_leaves()
|
||||
depth = estimator.depth_
|
||||
tdepth += depth
|
||||
tnodes += nodes
|
||||
tleaves += leaves
|
||||
self.depth_ = tdepth / self.n_estimators
|
||||
self.leaves_ = tleaves / self.n_estimators
|
||||
self.nodes_ = tnodes / self.n_estimators
|
||||
|
||||
@staticmethod
|
||||
def _parallel_build_tree(
|
||||
base_estimator_: Stree,
|
||||
@ -228,3 +241,7 @@ class Odte(BaseEnsemble, ClassifierMixin): # type: ignore
|
||||
for i in range(n_samples):
|
||||
result[i, predictions[i]] += 1
|
||||
return result / self.n_estimators
|
||||
|
||||
def nodes_leaves(self) -> list(float, float):
|
||||
check_is_fitted(self, "estimators_")
|
||||
return self.nodes_, self.leaves_
|
||||
|
@ -3,7 +3,7 @@ import unittest
|
||||
import os
|
||||
import random
|
||||
import warnings
|
||||
from sklearn.exceptions import ConvergenceWarning
|
||||
from sklearn.exceptions import ConvergenceWarning, NotFittedError
|
||||
|
||||
from odte import Odte
|
||||
from stree import Stree
|
||||
@ -191,3 +191,27 @@ class Odte_test(unittest.TestCase):
|
||||
from sklearn.utils.estimator_checks import check_estimator
|
||||
|
||||
check_estimator(Odte())
|
||||
|
||||
def test_nodes_leaves_not_fitted(self):
|
||||
tclf = Odte(
|
||||
base_estimator=Stree(),
|
||||
random_state=self._random_state,
|
||||
n_estimators=3,
|
||||
)
|
||||
with self.assertRaises(NotFittedError):
|
||||
tclf.nodes_leaves()
|
||||
|
||||
def test_nodes_leaves_depth(self):
|
||||
tclf = Odte(
|
||||
base_estimator=Stree(),
|
||||
random_state=self._random_state,
|
||||
n_estimators=3,
|
||||
)
|
||||
X, y = load_dataset(self._random_state, n_features=16, n_samples=500)
|
||||
tclf.fit(X, y)
|
||||
self.assertAlmostEqual(6.0, tclf.depth_)
|
||||
self.assertAlmostEqual(9.333333333333334, tclf.leaves_)
|
||||
self.assertAlmostEqual(17.666666666666668, tclf.nodes_)
|
||||
nodes, leaves = tclf.nodes_leaves()
|
||||
self.assertAlmostEqual(9.333333333333334, leaves)
|
||||
self.assertAlmostEqual(17.666666666666668, nodes)
|
||||
|
Loading…
x
Reference in New Issue
Block a user