Fix unintended nested if in partition

This commit is contained in:
2021-04-08 08:27:31 +02:00
parent 0f89b044f1
commit c36f685263
2 changed files with 10 additions and 10 deletions

View File

@@ -460,9 +460,9 @@ class Splitter:
# in predcit time just use the column computed in train time
# is taking the classifier of class <col>
col = node.get_partition_column()
if col == -1:
# No partition is producing information gain
data = np.ones(data.shape)
if col == -1:
# No partition is producing information gain
data = np.ones(data.shape)
data = data[:, col]
self._up = data > 0

View File

@@ -198,10 +198,10 @@ class Stree_test(unittest.TestCase):
"Synt": {
"max_samples linear": 0.9606666666666667,
"max_samples rbf": 0.7133333333333334,
"max_samples poly": 0.49066666666666664,
"max_samples poly": 0.618,
"impurity linear": 0.9606666666666667,
"impurity rbf": 0.7133333333333334,
"impurity poly": 0.49066666666666664,
"impurity poly": 0.618,
},
"Iris": {
"max_samples linear": 1.0,
@@ -378,7 +378,7 @@ class Stree_test(unittest.TestCase):
n_samples=500,
)
clf = Stree(kernel="rbf", random_state=self._random_state)
self.assertEqual(0.824, clf.fit(X, y).score(X, y))
self.assertEqual(0.768, clf.fit(X, y).score(X, y))
X, y = load_wine(return_X_y=True)
self.assertEqual(0.6741573033707865, clf.fit(X, y).score(X, y))
@@ -406,7 +406,7 @@ class Stree_test(unittest.TestCase):
clf = Stree(kernel="linear", random_state=self._random_state)
self.assertEqual(0.9533333333333334, clf.fit(X, y).score(X, y))
X, y = load_wine(return_X_y=True)
self.assertEqual(0.9550561797752809, clf.fit(X, y).score(X, y))
self.assertEqual(0.9831460674157303, clf.fit(X, y).score(X, y))
def test_zero_all_sample_weights(self):
X, y = load_dataset(self._random_state)
@@ -453,7 +453,7 @@ class Stree_test(unittest.TestCase):
X, y = load_wine(return_X_y=True)
clf = Stree(random_state=self._random_state)
clf.fit(X, y)
self.assertEqual(7, clf.depth_)
self.assertEqual(4, clf.depth_)
def test_nodes_leaves(self):
X, y = load_dataset(
@@ -471,5 +471,5 @@ class Stree_test(unittest.TestCase):
clf = Stree(random_state=self._random_state)
clf.fit(X, y)
nodes, leaves = clf.nodes_leaves()
self.assertEqual(8, nodes)
self.assertEquals(9, leaves)
self.assertEqual(4, nodes)
self.assertEquals(5, leaves)