From 02e75b3c3e3e7409fae47b781f860dafaa429645 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Mon, 27 Nov 2023 13:53:15 +0100 Subject: [PATCH] Fix depth/leaves/nodes no longer return average --- odte/Odte.py | 23 +++++++++++++---------- odte/_version.py | 2 +- odte/tests/Odte_tests.py | 20 ++++++++++---------- 3 files changed, 24 insertions(+), 21 deletions(-) diff --git a/odte/Odte.py b/odte/Odte.py index b7c08b3..179cf5b 100644 --- a/odte/Odte.py +++ b/odte/Odte.py @@ -91,7 +91,7 @@ class Odte(BaseEnsemble, ClassifierMixin): return self def _compute_metrics(self) -> None: - tdepth = tnodes = tleaves = 0.0 + tdepth = tnodes = tleaves = 0 for estimator in self.estimators_: if hasattr(estimator, "nodes_leaves"): nodes, leaves = estimator.nodes_leaves() @@ -99,9 +99,12 @@ class Odte(BaseEnsemble, ClassifierMixin): tdepth += depth tnodes += nodes tleaves += leaves - self.depth_ = tdepth / self.n_estimators - self.leaves_ = tleaves / self.n_estimators - self.nodes_ = tnodes / self.n_estimators + # self.depth_ = tdepth / self.n_estimators + # self.leaves_ = tleaves / self.n_estimators + # self.nodes_ = tnodes / self.n_estimators + self.depth_ = tdepth + self.leaves_ = tleaves + self.nodes_ = tnodes def _train( self, X: np.ndarray, y: np.ndarray, weights: np.ndarray @@ -250,16 +253,16 @@ class Odte(BaseEnsemble, ClassifierMixin): for i in range(n_samples): result[i, predictions[i]] += 1 return result / self.n_estimators - - def get_nodes(self) ->int: + + def get_nodes(self) -> int: check_is_fitted(self, "estimators_") return self.nodes_ - - def get_leaves(self) ->int: + + def get_leaves(self) -> int: check_is_fitted(self, "estimators_") return self.leaves_ - - def get_depth(self) ->int: + + def get_depth(self) -> int: check_is_fitted(self, "estimators_") return self.depth_ diff --git a/odte/_version.py b/odte/_version.py index a8d4557..d7b30e1 100644 --- a/odte/_version.py +++ b/odte/_version.py @@ -1 +1 @@ -__version__ = "0.3.5" +__version__ = "0.3.6" diff --git a/odte/tests/Odte_tests.py b/odte/tests/Odte_tests.py index 6550a94..79763a4 100644 --- a/odte/tests/Odte_tests.py +++ b/odte/tests/Odte_tests.py @@ -214,17 +214,17 @@ class Odte_test(unittest.TestCase): tclf.fit(X, y) tclf_p.fit(X, y) for clf in [tclf, tclf_p]: - self.assertAlmostEqual(5.8, clf.depth_) - self.assertAlmostEqual(5.8, clf.get_depth()) - self.assertAlmostEqual(9.4, clf.leaves_) - self.assertAlmostEqual(9.4, clf.get_leaves()) - self.assertAlmostEqual(17.8, clf.nodes_) - self.assertAlmostEqual(17.8, clf.get_nodes()) + self.assertEqual(29, clf.depth_) + self.assertEqual(29, clf.get_depth()) + self.assertEqual(47, clf.leaves_) + self.assertEqual(47, clf.get_leaves()) + self.assertEqual(89, clf.nodes_) + self.assertEqual(89, clf.get_nodes()) nodes, leaves = clf.nodes_leaves() - self.assertAlmostEqual(9.4, leaves) - self.assertAlmostEqual(9.4, clf.get_leaves()) - self.assertAlmostEqual(17.8, nodes) - self.assertAlmostEqual(17.8, clf.get_nodes()) + self.assertEqual(47, leaves) + self.assertEqual(47, clf.get_leaves()) + self.assertEqual(89, nodes) + self.assertEqual(89, clf.get_nodes()) def test_nodes_leaves_SVC(self): tclf = Odte(