From c36f685263a1309426252b8e6a239dbf518f8dd0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Thu, 8 Apr 2021 08:27:31 +0200 Subject: [PATCH] Fix unintended nested if in partition --- stree/Strees.py | 6 +++--- stree/tests/Stree_test.py | 14 +++++++------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/stree/Strees.py b/stree/Strees.py index fa0cbd9..0cfaa66 100644 --- a/stree/Strees.py +++ b/stree/Strees.py @@ -460,9 +460,9 @@ class Splitter: # in predcit time just use the column computed in train time # is taking the classifier of class col = node.get_partition_column() - if col == -1: - # No partition is producing information gain - data = np.ones(data.shape) + if col == -1: + # No partition is producing information gain + data = np.ones(data.shape) data = data[:, col] self._up = data > 0 diff --git a/stree/tests/Stree_test.py b/stree/tests/Stree_test.py index 5deb372..5a3eba1 100644 --- a/stree/tests/Stree_test.py +++ b/stree/tests/Stree_test.py @@ -198,10 +198,10 @@ class Stree_test(unittest.TestCase): "Synt": { "max_samples linear": 0.9606666666666667, "max_samples rbf": 0.7133333333333334, - "max_samples poly": 0.49066666666666664, + "max_samples poly": 0.618, "impurity linear": 0.9606666666666667, "impurity rbf": 0.7133333333333334, - "impurity poly": 0.49066666666666664, + "impurity poly": 0.618, }, "Iris": { "max_samples linear": 1.0, @@ -378,7 +378,7 @@ class Stree_test(unittest.TestCase): n_samples=500, ) clf = Stree(kernel="rbf", random_state=self._random_state) - self.assertEqual(0.824, clf.fit(X, y).score(X, y)) + self.assertEqual(0.768, clf.fit(X, y).score(X, y)) X, y = load_wine(return_X_y=True) self.assertEqual(0.6741573033707865, clf.fit(X, y).score(X, y)) @@ -406,7 +406,7 @@ class Stree_test(unittest.TestCase): clf = Stree(kernel="linear", random_state=self._random_state) self.assertEqual(0.9533333333333334, clf.fit(X, y).score(X, y)) X, y = load_wine(return_X_y=True) - self.assertEqual(0.9550561797752809, clf.fit(X, y).score(X, y)) + self.assertEqual(0.9831460674157303, clf.fit(X, y).score(X, y)) def test_zero_all_sample_weights(self): X, y = load_dataset(self._random_state) @@ -453,7 +453,7 @@ class Stree_test(unittest.TestCase): X, y = load_wine(return_X_y=True) clf = Stree(random_state=self._random_state) clf.fit(X, y) - self.assertEqual(7, clf.depth_) + self.assertEqual(4, clf.depth_) def test_nodes_leaves(self): X, y = load_dataset( @@ -471,5 +471,5 @@ class Stree_test(unittest.TestCase): clf = Stree(random_state=self._random_state) clf.fit(X, y) nodes, leaves = clf.nodes_leaves() - self.assertEqual(8, nodes) - self.assertEquals(9, leaves) + self.assertEqual(4, nodes) + self.assertEquals(5, leaves)