From 52d10951618c64d341db6032edfcf21750a03a55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Mon, 27 Nov 2023 10:33:47 +0100 Subject: [PATCH] Add separate methods to return nodes/leaves/depth --- odte/Odte.py | 12 ++++++++++++ odte/_version.py | 2 +- odte/tests/Odte_tests.py | 15 +++++++++++++++ 3 files changed, 28 insertions(+), 1 deletion(-) diff --git a/odte/Odte.py b/odte/Odte.py index 9b9f26c..b7c08b3 100644 --- a/odte/Odte.py +++ b/odte/Odte.py @@ -250,6 +250,18 @@ class Odte(BaseEnsemble, ClassifierMixin): for i in range(n_samples): result[i, predictions[i]] += 1 return result / self.n_estimators + + def get_nodes(self) ->int: + check_is_fitted(self, "estimators_") + return self.nodes_ + + def get_leaves(self) ->int: + check_is_fitted(self, "estimators_") + return self.leaves_ + + def get_depth(self) ->int: + check_is_fitted(self, "estimators_") + return self.depth_ def nodes_leaves(self) -> Tuple[float, float]: check_is_fitted(self, "estimators_") diff --git a/odte/_version.py b/odte/_version.py index 334b899..a8d4557 100644 --- a/odte/_version.py +++ b/odte/_version.py @@ -1 +1 @@ -__version__ = "0.3.4" +__version__ = "0.3.5" diff --git a/odte/tests/Odte_tests.py b/odte/tests/Odte_tests.py index baa572e..6550a94 100644 --- a/odte/tests/Odte_tests.py +++ b/odte/tests/Odte_tests.py @@ -190,6 +190,12 @@ class Odte_test(unittest.TestCase): ) with self.assertRaises(NotFittedError): tclf.nodes_leaves() + with self.assertRaises(NotFittedError): + tclf.get_nodes() + with self.assertRaises(NotFittedError): + tclf.get_leaves() + with self.assertRaises(NotFittedError): + tclf.get_depth() def test_nodes_leaves_depth(self): tclf = Odte( @@ -209,11 +215,16 @@ class Odte_test(unittest.TestCase): tclf_p.fit(X, y) for clf in [tclf, tclf_p]: self.assertAlmostEqual(5.8, clf.depth_) + self.assertAlmostEqual(5.8, clf.get_depth()) self.assertAlmostEqual(9.4, clf.leaves_) + self.assertAlmostEqual(9.4, clf.get_leaves()) self.assertAlmostEqual(17.8, clf.nodes_) + self.assertAlmostEqual(17.8, clf.get_nodes()) nodes, leaves = clf.nodes_leaves() self.assertAlmostEqual(9.4, leaves) + self.assertAlmostEqual(9.4, clf.get_leaves()) self.assertAlmostEqual(17.8, nodes) + self.assertAlmostEqual(17.8, clf.get_nodes()) def test_nodes_leaves_SVC(self): tclf = Odte( @@ -224,10 +235,14 @@ class Odte_test(unittest.TestCase): X, y = load_dataset(self._random_state, n_features=16, n_samples=500) tclf.fit(X, y) self.assertAlmostEqual(0.0, tclf.leaves_) + self.assertAlmostEqual(0.0, tclf.get_leaves()) self.assertAlmostEqual(0.0, tclf.nodes_) + self.assertAlmostEqual(0.0, tclf.get_nodes()) nodes, leaves = tclf.nodes_leaves() self.assertAlmostEqual(0.0, leaves) + self.assertAlmostEqual(0.0, tclf.get_leaves()) self.assertAlmostEqual(0.0, nodes) + self.assertAlmostEqual(0.0, tclf.get_nodes()) def test_estimator_hyperparams(self): data = [