diff --git a/stree/Strees.py b/stree/Strees.py index ab35b4b..7c29a2e 100644 --- a/stree/Strees.py +++ b/stree/Strees.py @@ -635,6 +635,7 @@ class Stree(BaseEstimator, ClassifierMixin): X = X[~indices_zero, :] y = y[~indices_zero] sample_weight = sample_weight[~indices_zero] + self.depth_ = max(depth, self.depth_) if np.unique(y).shape[0] == 1: # only 1 class => pure dataset return Snode( @@ -652,7 +653,6 @@ class Stree(BaseEstimator, ClassifierMixin): clf.fit(Xs, y, sample_weight=sample_weight) impurity = self.splitter_.partition_impurity(y) node = Snode(clf, X, y, features, impurity, title, sample_weight) - self.depth_ = max(depth, self.depth_) self.splitter_.partition(X, node, True) X_U, X_D = self.splitter_.part(X) y_u, y_d = self.splitter_.part(y) @@ -668,8 +668,14 @@ class Stree(BaseEstimator, ClassifierMixin): title=title + ", ", weight=sample_weight, ) - node.set_up(self.train(X_U, y_u, sw_u, depth + 1, title + " - Up")) - node.set_down(self.train(X_D, y_d, sw_d, depth + 1, title + " - Down")) + node.set_up( + self.train(X_U, y_u, sw_u, depth + 1, title + f" - Up({depth+1})") + ) + node.set_down( + self.train( + X_D, y_d, sw_d, depth + 1, title + f" - Down({depth+1})" + ) + ) return node def _build_predictor(self): diff --git a/stree/tests/Stree_test.py b/stree/tests/Stree_test.py index 65afeef..319f60f 100644 --- a/stree/tests/Stree_test.py +++ b/stree/tests/Stree_test.py @@ -103,20 +103,20 @@ class Stree_test(unittest.TestCase): def test_iterator_and_str(self): """Check preorder iterator""" expected = [ - "root feaures=(0, 1, 2) impurity=1.0000 counts=(array([0, 1]), arr" - "ay([750, 750]))", - "root - Down, - Leaf class=0 belief= 0.928297 impurity=0.37" - "22 counts=(array([0, 1]), array([725, 56]))", - "root - Up feaures=(0, 1, 2) impurity=0.2178 counts=(array([0, 1])" - ", array([ 25, 694]))", - "root - Up - Down feaures=(0, 1, 2) impurity=0.8454 counts=(array(" - "[0, 1]), array([8, 3]))", - "root - Up - Down - Down, - Leaf class=0 belief= 1.000000 i" - "mpurity=0.0000 counts=(array([0]), array([7]))", - "root - Up - Down - Up, - Leaf class=1 belief= 0.750000 imp" - "urity=0.8113 counts=(array([0, 1]), array([1, 3]))", - "root - Up - Up, - Leaf class=1 belief= 0.975989 impurity=0" - ".1634 counts=(array([0, 1]), array([ 17, 691]))", + "root feaures=(0, 1, 2) impurity=1.0000 counts=(array([0, 1]), " + "array([750, 750]))", + "root - Down(2), - Leaf class=0 belief= 0.928297 impurity=" + "0.3722 counts=(array([0, 1]), array([725, 56]))", + "root - Up(2) feaures=(0, 1, 2) impurity=0.2178 counts=(array([0, " + "1]), array([ 25, 694]))", + "root - Up(2) - Down(3) feaures=(0, 1, 2) impurity=0.8454 counts=" + "(array([0, 1]), array([8, 3]))", + "root - Up(2) - Down(3) - Down(4), - Leaf class=0 belief= " + "1.000000 impurity=0.0000 counts=(array([0]), array([7]))", + "root - Up(2) - Down(3) - Up(4), - Leaf class=1 belief= " + "0.750000 impurity=0.8113 counts=(array([0, 1]), array([1, 3]))", + "root - Up(2) - Up(3), - Leaf class=1 belief= 0.975989 " + "impurity=0.1634 counts=(array([0, 1]), array([ 17, 691]))", ] computed = [] expected_string = "" @@ -439,3 +439,18 @@ class Stree_test(unittest.TestCase): self.assertEqual(model1.score(X, y), 1) self.assertAlmostEqual(model2.score(X, y), 0.66666667) self.assertEqual(model2.score(X, y, w), 1) + + def test_depth(self): + X, y = load_dataset( + random_state=self._random_state, + n_classes=3, + n_features=5, + n_samples=1500, + ) + clf = Stree(random_state=self._random_state) + clf.fit(X, y) + self.assertEqual(6, clf.depth_) + X, y = load_wine(return_X_y=True) + clf = Stree(random_state=self._random_state) + clf.fit(X, y) + self.assertEqual(7, clf.depth_)