mirror of
https://github.com/Doctorado-ML/Odte.git
synced 2025-07-11 00:02:30 +00:00
Add nodes, leaves and depth average
This commit is contained in:
parent
cfda03682b
commit
3a06c9d1cc
17
odte/Odte.py
17
odte/Odte.py
@ -83,8 +83,21 @@ class Odte(BaseEnsemble, ClassifierMixin): # type: ignore
|
|||||||
self.subspaces_: List[Tuple[int, ...]] = []
|
self.subspaces_: List[Tuple[int, ...]] = []
|
||||||
result = self._train(X, y, sample_weight)
|
result = self._train(X, y, sample_weight)
|
||||||
self.estimators_, self.subspaces_ = tuple(zip(*result)) # type: ignore
|
self.estimators_, self.subspaces_ = tuple(zip(*result)) # type: ignore
|
||||||
|
self._compute_metrics()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def _compute_metrics(self) -> None:
|
||||||
|
tdepth = tnodes = tleaves = 0
|
||||||
|
for estimator in self.estimators_:
|
||||||
|
nodes, leaves = estimator.nodes_leaves()
|
||||||
|
depth = estimator.depth_
|
||||||
|
tdepth += depth
|
||||||
|
tnodes += nodes
|
||||||
|
tleaves += leaves
|
||||||
|
self.depth_ = tdepth / self.n_estimators
|
||||||
|
self.leaves_ = tleaves / self.n_estimators
|
||||||
|
self.nodes_ = tnodes / self.n_estimators
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _parallel_build_tree(
|
def _parallel_build_tree(
|
||||||
base_estimator_: Stree,
|
base_estimator_: Stree,
|
||||||
@ -228,3 +241,7 @@ class Odte(BaseEnsemble, ClassifierMixin): # type: ignore
|
|||||||
for i in range(n_samples):
|
for i in range(n_samples):
|
||||||
result[i, predictions[i]] += 1
|
result[i, predictions[i]] += 1
|
||||||
return result / self.n_estimators
|
return result / self.n_estimators
|
||||||
|
|
||||||
|
def nodes_leaves(self) -> list(float, float):
|
||||||
|
check_is_fitted(self, "estimators_")
|
||||||
|
return self.nodes_, self.leaves_
|
||||||
|
@ -3,7 +3,7 @@ import unittest
|
|||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import warnings
|
import warnings
|
||||||
from sklearn.exceptions import ConvergenceWarning
|
from sklearn.exceptions import ConvergenceWarning, NotFittedError
|
||||||
|
|
||||||
from odte import Odte
|
from odte import Odte
|
||||||
from stree import Stree
|
from stree import Stree
|
||||||
@ -191,3 +191,27 @@ class Odte_test(unittest.TestCase):
|
|||||||
from sklearn.utils.estimator_checks import check_estimator
|
from sklearn.utils.estimator_checks import check_estimator
|
||||||
|
|
||||||
check_estimator(Odte())
|
check_estimator(Odte())
|
||||||
|
|
||||||
|
def test_nodes_leaves_not_fitted(self):
|
||||||
|
tclf = Odte(
|
||||||
|
base_estimator=Stree(),
|
||||||
|
random_state=self._random_state,
|
||||||
|
n_estimators=3,
|
||||||
|
)
|
||||||
|
with self.assertRaises(NotFittedError):
|
||||||
|
tclf.nodes_leaves()
|
||||||
|
|
||||||
|
def test_nodes_leaves_depth(self):
|
||||||
|
tclf = Odte(
|
||||||
|
base_estimator=Stree(),
|
||||||
|
random_state=self._random_state,
|
||||||
|
n_estimators=3,
|
||||||
|
)
|
||||||
|
X, y = load_dataset(self._random_state, n_features=16, n_samples=500)
|
||||||
|
tclf.fit(X, y)
|
||||||
|
self.assertAlmostEqual(6.0, tclf.depth_)
|
||||||
|
self.assertAlmostEqual(9.333333333333334, tclf.leaves_)
|
||||||
|
self.assertAlmostEqual(17.666666666666668, tclf.nodes_)
|
||||||
|
nodes, leaves = tclf.nodes_leaves()
|
||||||
|
self.assertAlmostEqual(9.333333333333334, leaves)
|
||||||
|
self.assertAlmostEqual(17.666666666666668, nodes)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user