diff --git a/stree/Strees.py b/stree/Strees.py
index fa0cbd9..0cfaa66 100644
--- a/stree/Strees.py
+++ b/stree/Strees.py
@@ -460,9 +460,9 @@ class Splitter:
# in predcit time just use the column computed in train time
# is taking the classifier of class
col = node.get_partition_column()
- if col == -1:
- # No partition is producing information gain
- data = np.ones(data.shape)
+ if col == -1:
+ # No partition is producing information gain
+ data = np.ones(data.shape)
data = data[:, col]
self._up = data > 0
diff --git a/stree/tests/Stree_test.py b/stree/tests/Stree_test.py
index 5deb372..5a3eba1 100644
--- a/stree/tests/Stree_test.py
+++ b/stree/tests/Stree_test.py
@@ -198,10 +198,10 @@ class Stree_test(unittest.TestCase):
"Synt": {
"max_samples linear": 0.9606666666666667,
"max_samples rbf": 0.7133333333333334,
- "max_samples poly": 0.49066666666666664,
+ "max_samples poly": 0.618,
"impurity linear": 0.9606666666666667,
"impurity rbf": 0.7133333333333334,
- "impurity poly": 0.49066666666666664,
+ "impurity poly": 0.618,
},
"Iris": {
"max_samples linear": 1.0,
@@ -378,7 +378,7 @@ class Stree_test(unittest.TestCase):
n_samples=500,
)
clf = Stree(kernel="rbf", random_state=self._random_state)
- self.assertEqual(0.824, clf.fit(X, y).score(X, y))
+ self.assertEqual(0.768, clf.fit(X, y).score(X, y))
X, y = load_wine(return_X_y=True)
self.assertEqual(0.6741573033707865, clf.fit(X, y).score(X, y))
@@ -406,7 +406,7 @@ class Stree_test(unittest.TestCase):
clf = Stree(kernel="linear", random_state=self._random_state)
self.assertEqual(0.9533333333333334, clf.fit(X, y).score(X, y))
X, y = load_wine(return_X_y=True)
- self.assertEqual(0.9550561797752809, clf.fit(X, y).score(X, y))
+ self.assertEqual(0.9831460674157303, clf.fit(X, y).score(X, y))
def test_zero_all_sample_weights(self):
X, y = load_dataset(self._random_state)
@@ -453,7 +453,7 @@ class Stree_test(unittest.TestCase):
X, y = load_wine(return_X_y=True)
clf = Stree(random_state=self._random_state)
clf.fit(X, y)
- self.assertEqual(7, clf.depth_)
+ self.assertEqual(4, clf.depth_)
def test_nodes_leaves(self):
X, y = load_dataset(
@@ -471,5 +471,5 @@ class Stree_test(unittest.TestCase):
clf = Stree(random_state=self._random_state)
clf.fit(X, y)
nodes, leaves = clf.nodes_leaves()
- self.assertEqual(8, nodes)
- self.assertEquals(9, leaves)
+ self.assertEqual(4, nodes)
+ self.assertEquals(5, leaves)