mirror of
https://github.com/Doctorado-ML/Odte.git
synced 2025-07-11 08:12:06 +00:00
Add separate methods to return nodes/leaves/depth
This commit is contained in:
parent
f9b83adfee
commit
52d1095161
12
odte/Odte.py
12
odte/Odte.py
@ -250,6 +250,18 @@ class Odte(BaseEnsemble, ClassifierMixin):
|
||||
for i in range(n_samples):
|
||||
result[i, predictions[i]] += 1
|
||||
return result / self.n_estimators
|
||||
|
||||
def get_nodes(self) ->int:
|
||||
check_is_fitted(self, "estimators_")
|
||||
return self.nodes_
|
||||
|
||||
def get_leaves(self) ->int:
|
||||
check_is_fitted(self, "estimators_")
|
||||
return self.leaves_
|
||||
|
||||
def get_depth(self) ->int:
|
||||
check_is_fitted(self, "estimators_")
|
||||
return self.depth_
|
||||
|
||||
def nodes_leaves(self) -> Tuple[float, float]:
|
||||
check_is_fitted(self, "estimators_")
|
||||
|
@ -1 +1 @@
|
||||
__version__ = "0.3.4"
|
||||
__version__ = "0.3.5"
|
||||
|
@ -190,6 +190,12 @@ class Odte_test(unittest.TestCase):
|
||||
)
|
||||
with self.assertRaises(NotFittedError):
|
||||
tclf.nodes_leaves()
|
||||
with self.assertRaises(NotFittedError):
|
||||
tclf.get_nodes()
|
||||
with self.assertRaises(NotFittedError):
|
||||
tclf.get_leaves()
|
||||
with self.assertRaises(NotFittedError):
|
||||
tclf.get_depth()
|
||||
|
||||
def test_nodes_leaves_depth(self):
|
||||
tclf = Odte(
|
||||
@ -209,11 +215,16 @@ class Odte_test(unittest.TestCase):
|
||||
tclf_p.fit(X, y)
|
||||
for clf in [tclf, tclf_p]:
|
||||
self.assertAlmostEqual(5.8, clf.depth_)
|
||||
self.assertAlmostEqual(5.8, clf.get_depth())
|
||||
self.assertAlmostEqual(9.4, clf.leaves_)
|
||||
self.assertAlmostEqual(9.4, clf.get_leaves())
|
||||
self.assertAlmostEqual(17.8, clf.nodes_)
|
||||
self.assertAlmostEqual(17.8, clf.get_nodes())
|
||||
nodes, leaves = clf.nodes_leaves()
|
||||
self.assertAlmostEqual(9.4, leaves)
|
||||
self.assertAlmostEqual(9.4, clf.get_leaves())
|
||||
self.assertAlmostEqual(17.8, nodes)
|
||||
self.assertAlmostEqual(17.8, clf.get_nodes())
|
||||
|
||||
def test_nodes_leaves_SVC(self):
|
||||
tclf = Odte(
|
||||
@ -224,10 +235,14 @@ class Odte_test(unittest.TestCase):
|
||||
X, y = load_dataset(self._random_state, n_features=16, n_samples=500)
|
||||
tclf.fit(X, y)
|
||||
self.assertAlmostEqual(0.0, tclf.leaves_)
|
||||
self.assertAlmostEqual(0.0, tclf.get_leaves())
|
||||
self.assertAlmostEqual(0.0, tclf.nodes_)
|
||||
self.assertAlmostEqual(0.0, tclf.get_nodes())
|
||||
nodes, leaves = tclf.nodes_leaves()
|
||||
self.assertAlmostEqual(0.0, leaves)
|
||||
self.assertAlmostEqual(0.0, tclf.get_leaves())
|
||||
self.assertAlmostEqual(0.0, nodes)
|
||||
self.assertAlmostEqual(0.0, tclf.get_nodes())
|
||||
|
||||
def test_estimator_hyperparams(self):
|
||||
data = [
|
||||
|
Loading…
x
Reference in New Issue
Block a user