mirror of
https://github.com/Doctorado-ML/Odte.git
synced 2025-07-11 08:12:06 +00:00
Fix depth/leaves/nodes no longer return average
This commit is contained in:
parent
52d1095161
commit
02e75b3c3e
17
odte/Odte.py
17
odte/Odte.py
@ -91,7 +91,7 @@ class Odte(BaseEnsemble, ClassifierMixin):
|
||||
return self
|
||||
|
||||
def _compute_metrics(self) -> None:
|
||||
tdepth = tnodes = tleaves = 0.0
|
||||
tdepth = tnodes = tleaves = 0
|
||||
for estimator in self.estimators_:
|
||||
if hasattr(estimator, "nodes_leaves"):
|
||||
nodes, leaves = estimator.nodes_leaves()
|
||||
@ -99,9 +99,12 @@ class Odte(BaseEnsemble, ClassifierMixin):
|
||||
tdepth += depth
|
||||
tnodes += nodes
|
||||
tleaves += leaves
|
||||
self.depth_ = tdepth / self.n_estimators
|
||||
self.leaves_ = tleaves / self.n_estimators
|
||||
self.nodes_ = tnodes / self.n_estimators
|
||||
# self.depth_ = tdepth / self.n_estimators
|
||||
# self.leaves_ = tleaves / self.n_estimators
|
||||
# self.nodes_ = tnodes / self.n_estimators
|
||||
self.depth_ = tdepth
|
||||
self.leaves_ = tleaves
|
||||
self.nodes_ = tnodes
|
||||
|
||||
def _train(
|
||||
self, X: np.ndarray, y: np.ndarray, weights: np.ndarray
|
||||
@ -251,15 +254,15 @@ class Odte(BaseEnsemble, ClassifierMixin):
|
||||
result[i, predictions[i]] += 1
|
||||
return result / self.n_estimators
|
||||
|
||||
def get_nodes(self) ->int:
|
||||
def get_nodes(self) -> int:
|
||||
check_is_fitted(self, "estimators_")
|
||||
return self.nodes_
|
||||
|
||||
def get_leaves(self) ->int:
|
||||
def get_leaves(self) -> int:
|
||||
check_is_fitted(self, "estimators_")
|
||||
return self.leaves_
|
||||
|
||||
def get_depth(self) ->int:
|
||||
def get_depth(self) -> int:
|
||||
check_is_fitted(self, "estimators_")
|
||||
return self.depth_
|
||||
|
||||
|
@ -1 +1 @@
|
||||
__version__ = "0.3.5"
|
||||
__version__ = "0.3.6"
|
||||
|
@ -214,17 +214,17 @@ class Odte_test(unittest.TestCase):
|
||||
tclf.fit(X, y)
|
||||
tclf_p.fit(X, y)
|
||||
for clf in [tclf, tclf_p]:
|
||||
self.assertAlmostEqual(5.8, clf.depth_)
|
||||
self.assertAlmostEqual(5.8, clf.get_depth())
|
||||
self.assertAlmostEqual(9.4, clf.leaves_)
|
||||
self.assertAlmostEqual(9.4, clf.get_leaves())
|
||||
self.assertAlmostEqual(17.8, clf.nodes_)
|
||||
self.assertAlmostEqual(17.8, clf.get_nodes())
|
||||
self.assertEqual(29, clf.depth_)
|
||||
self.assertEqual(29, clf.get_depth())
|
||||
self.assertEqual(47, clf.leaves_)
|
||||
self.assertEqual(47, clf.get_leaves())
|
||||
self.assertEqual(89, clf.nodes_)
|
||||
self.assertEqual(89, clf.get_nodes())
|
||||
nodes, leaves = clf.nodes_leaves()
|
||||
self.assertAlmostEqual(9.4, leaves)
|
||||
self.assertAlmostEqual(9.4, clf.get_leaves())
|
||||
self.assertAlmostEqual(17.8, nodes)
|
||||
self.assertAlmostEqual(17.8, clf.get_nodes())
|
||||
self.assertEqual(47, leaves)
|
||||
self.assertEqual(47, clf.get_leaves())
|
||||
self.assertEqual(89, nodes)
|
||||
self.assertEqual(89, clf.get_nodes())
|
||||
|
||||
def test_nodes_leaves_SVC(self):
|
||||
tclf = Odte(
|
||||
|
Loading…
x
Reference in New Issue
Block a user