Fix depth/leaves/nodes no longer return average

This commit is contained in:
Ricardo Montañana Gómez 2023-11-27 13:53:15 +01:00
parent 52d1095161
commit 02e75b3c3e
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
3 changed files with 24 additions and 21 deletions

View File

@ -91,7 +91,7 @@ class Odte(BaseEnsemble, ClassifierMixin):
return self return self
def _compute_metrics(self) -> None: def _compute_metrics(self) -> None:
tdepth = tnodes = tleaves = 0.0 tdepth = tnodes = tleaves = 0
for estimator in self.estimators_: for estimator in self.estimators_:
if hasattr(estimator, "nodes_leaves"): if hasattr(estimator, "nodes_leaves"):
nodes, leaves = estimator.nodes_leaves() nodes, leaves = estimator.nodes_leaves()
@ -99,9 +99,12 @@ class Odte(BaseEnsemble, ClassifierMixin):
tdepth += depth tdepth += depth
tnodes += nodes tnodes += nodes
tleaves += leaves tleaves += leaves
self.depth_ = tdepth / self.n_estimators # self.depth_ = tdepth / self.n_estimators
self.leaves_ = tleaves / self.n_estimators # self.leaves_ = tleaves / self.n_estimators
self.nodes_ = tnodes / self.n_estimators # self.nodes_ = tnodes / self.n_estimators
self.depth_ = tdepth
self.leaves_ = tleaves
self.nodes_ = tnodes
def _train( def _train(
self, X: np.ndarray, y: np.ndarray, weights: np.ndarray self, X: np.ndarray, y: np.ndarray, weights: np.ndarray

View File

@ -1 +1 @@
__version__ = "0.3.5" __version__ = "0.3.6"

View File

@ -214,17 +214,17 @@ class Odte_test(unittest.TestCase):
tclf.fit(X, y) tclf.fit(X, y)
tclf_p.fit(X, y) tclf_p.fit(X, y)
for clf in [tclf, tclf_p]: for clf in [tclf, tclf_p]:
self.assertAlmostEqual(5.8, clf.depth_) self.assertEqual(29, clf.depth_)
self.assertAlmostEqual(5.8, clf.get_depth()) self.assertEqual(29, clf.get_depth())
self.assertAlmostEqual(9.4, clf.leaves_) self.assertEqual(47, clf.leaves_)
self.assertAlmostEqual(9.4, clf.get_leaves()) self.assertEqual(47, clf.get_leaves())
self.assertAlmostEqual(17.8, clf.nodes_) self.assertEqual(89, clf.nodes_)
self.assertAlmostEqual(17.8, clf.get_nodes()) self.assertEqual(89, clf.get_nodes())
nodes, leaves = clf.nodes_leaves() nodes, leaves = clf.nodes_leaves()
self.assertAlmostEqual(9.4, leaves) self.assertEqual(47, leaves)
self.assertAlmostEqual(9.4, clf.get_leaves()) self.assertEqual(47, clf.get_leaves())
self.assertAlmostEqual(17.8, nodes) self.assertEqual(89, nodes)
self.assertAlmostEqual(17.8, clf.get_nodes()) self.assertEqual(89, clf.get_nodes())
def test_nodes_leaves_SVC(self): def test_nodes_leaves_SVC(self):
tclf = Odte( tclf = Odte(