mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-15 15:36:00 +00:00
Fix unintended nested if in partition
This commit is contained in:
@@ -460,9 +460,9 @@ class Splitter:
|
||||
# in predcit time just use the column computed in train time
|
||||
# is taking the classifier of class <col>
|
||||
col = node.get_partition_column()
|
||||
if col == -1:
|
||||
# No partition is producing information gain
|
||||
data = np.ones(data.shape)
|
||||
if col == -1:
|
||||
# No partition is producing information gain
|
||||
data = np.ones(data.shape)
|
||||
data = data[:, col]
|
||||
self._up = data > 0
|
||||
|
||||
|
@@ -198,10 +198,10 @@ class Stree_test(unittest.TestCase):
|
||||
"Synt": {
|
||||
"max_samples linear": 0.9606666666666667,
|
||||
"max_samples rbf": 0.7133333333333334,
|
||||
"max_samples poly": 0.49066666666666664,
|
||||
"max_samples poly": 0.618,
|
||||
"impurity linear": 0.9606666666666667,
|
||||
"impurity rbf": 0.7133333333333334,
|
||||
"impurity poly": 0.49066666666666664,
|
||||
"impurity poly": 0.618,
|
||||
},
|
||||
"Iris": {
|
||||
"max_samples linear": 1.0,
|
||||
@@ -378,7 +378,7 @@ class Stree_test(unittest.TestCase):
|
||||
n_samples=500,
|
||||
)
|
||||
clf = Stree(kernel="rbf", random_state=self._random_state)
|
||||
self.assertEqual(0.824, clf.fit(X, y).score(X, y))
|
||||
self.assertEqual(0.768, clf.fit(X, y).score(X, y))
|
||||
X, y = load_wine(return_X_y=True)
|
||||
self.assertEqual(0.6741573033707865, clf.fit(X, y).score(X, y))
|
||||
|
||||
@@ -406,7 +406,7 @@ class Stree_test(unittest.TestCase):
|
||||
clf = Stree(kernel="linear", random_state=self._random_state)
|
||||
self.assertEqual(0.9533333333333334, clf.fit(X, y).score(X, y))
|
||||
X, y = load_wine(return_X_y=True)
|
||||
self.assertEqual(0.9550561797752809, clf.fit(X, y).score(X, y))
|
||||
self.assertEqual(0.9831460674157303, clf.fit(X, y).score(X, y))
|
||||
|
||||
def test_zero_all_sample_weights(self):
|
||||
X, y = load_dataset(self._random_state)
|
||||
@@ -453,7 +453,7 @@ class Stree_test(unittest.TestCase):
|
||||
X, y = load_wine(return_X_y=True)
|
||||
clf = Stree(random_state=self._random_state)
|
||||
clf.fit(X, y)
|
||||
self.assertEqual(7, clf.depth_)
|
||||
self.assertEqual(4, clf.depth_)
|
||||
|
||||
def test_nodes_leaves(self):
|
||||
X, y = load_dataset(
|
||||
@@ -471,5 +471,5 @@ class Stree_test(unittest.TestCase):
|
||||
clf = Stree(random_state=self._random_state)
|
||||
clf.fit(X, y)
|
||||
nodes, leaves = clf.nodes_leaves()
|
||||
self.assertEqual(8, nodes)
|
||||
self.assertEquals(9, leaves)
|
||||
self.assertEqual(4, nodes)
|
||||
self.assertEquals(5, leaves)
|
||||
|
Reference in New Issue
Block a user