#2 Refactor Stree & create Splitter

Add and test splitter parameter
This commit is contained in:
2020-06-15 00:22:57 +02:00
parent 502ee72799
commit c94bc068bd
7 changed files with 529 additions and 176 deletions

View File

@@ -204,13 +204,11 @@ class Stree_test(unittest.TestCase):
self.assertEqual(0, len(list(tcl)))
def test_min_samples_split(self):
    """Check that ``min_samples_split`` controls whether the tree is split.

    The same three-sample dataset is fitted with a threshold of 3 and with
    a threshold of 4. With 3 the fitted ``tree_`` node must expose both a
    "down" and an "up" child; with 4 it must expose neither.
    """
    dataset = [[1], [2], [3]], [1, 1, 0]
    # Build and fit each estimator exactly once; fitting a throw-away
    # instance first and then rebinding the name only repeats the work.
    tcl_split = Stree(min_samples_split=3).fit(*dataset)
    self.assertIsNotNone(tcl_split.tree_.get_down())
    self.assertIsNotNone(tcl_split.tree_.get_up())
    tcl_nosplit = Stree(min_samples_split=4).fit(*dataset)
    self.assertIsNone(tcl_nosplit.tree_.get_down())
    self.assertIsNone(tcl_nosplit.tree_.get_up())
@@ -265,37 +263,6 @@ class Stree_test(unittest.TestCase):
outcome = outcomes[name][f"{criteria} {kernel}"]
self.assertAlmostEqual(outcome, clf.score(px, py))
def test_min_distance(self):
    """Check ``Stree._min_distance`` on a small 4x3 matrix.

    The expected vector holds, for every row of the input, the entry with
    the smallest absolute value; the result must be one-dimensional with
    one element per row.
    """
    rows = [
        [-0.1, 0.2, -0.3],
        [0.7, 0.01, -0.1],
        [0.7, -0.9, 0.5],
        [0.1, 0.2, 0.3],
    ]
    wanted = np.array([-0.1, 0.01, 0.5, 0.1])
    estimator = Stree()
    result = estimator._min_distance(np.array(rows), None)
    self.assertEqual((4,), result.shape)
    self.assertListEqual(wanted.tolist(), result.tolist())
def test_max_samples(self):
    """Check ``Stree._max_samples`` on a small 4x3 matrix with labels.

    For this input the expected vector equals the second column of the
    matrix; the result must be one-dimensional with one element per row.
    """
    rows = [
        [-0.1, 0.2, -0.3],
        [0.7, 0.01, -0.1],
        [0.7, -0.9, 0.5],
        [0.1, 0.2, 0.3],
    ]
    labels = [1, 2, 1, 0]
    wanted = np.array([0.2, 0.01, -0.9, 0.2])
    estimator = Stree()
    result = estimator._max_samples(np.array(rows), labels)
    self.assertEqual((4,), result.shape)
    self.assertListEqual(wanted.tolist(), result.tolist())
def test_max_features(self):
n_features = 16
expected_values = [
@@ -334,7 +301,9 @@ class Stree_test(unittest.TestCase):
for max_features, expected in expected_values:
clf.set_params(**dict(max_features=max_features))
clf.fit(dataset, y)
computed, indices = clf._get_subspace(dataset)
computed, indices = clf.splitter_.get_subspace(
dataset, y, clf.max_features_
)
self.assertListEqual(
dataset[:, indices].tolist(), computed.tolist()
)
@@ -345,22 +314,6 @@ class Stree_test(unittest.TestCase):
with self.assertRaises(ValueError):
clf.fit(*load_dataset())
def test_gini(self):
    """Check the Gini impurity of a 4-vs-6 binary label vector.

    The value is verified twice: through the ``Stree._gini`` helper called
    on the class, and through ``criterion_function_`` of an estimator
    fitted with ``criterion="gini"``.
    """
    labels = [0, 1, 1, 1, 1, 1, 0, 0, 0, 1]
    wanted = 0.48
    # Direct call on the class, before any estimator is created.
    self.assertEqual(wanted, Stree._gini(labels))
    # Same value through the criterion selected at fit time.
    estimator = Stree(criterion="gini")
    estimator.fit(*load_dataset())
    self.assertEqual(wanted, estimator.criterion_function_(labels))
def test_entropy(self):
    """Check the entropy of a 4-vs-6 binary label vector.

    The value is verified twice: through the ``Stree._entropy`` helper
    called on the class, and through ``criterion_function_`` of an
    estimator fitted with ``criterion="entropy"``.
    """
    y = [0, 1, 1, 1, 1, 1, 0, 0, 0, 1]
    expected = 0.9709505944546686
    self.assertAlmostEqual(expected, Stree._entropy(y))
    clf = Stree(criterion="entropy")
    clf.fit(*load_dataset())
    # The expected value is a rounded float literal, so compare with a
    # tolerance here as well instead of demanding bit-exact equality.
    self.assertAlmostEqual(expected, clf.criterion_function_(y))
def test_predict_feature_dimensions(self):
X = np.random.rand(10, 5)
y = np.random.randint(0, 2, 10)
@@ -374,3 +327,8 @@ class Stree_test(unittest.TestCase):
clf = Stree(random_state=self._random_state, max_features=2)
clf.fit(X, y)
self.assertAlmostEqual(0.9426666666666667, clf.score(X, y))
def test_bogus_splitter_parameter(self):
    """Fitting with an unknown ``splitter`` value must raise ValueError."""
    bogus_clf = Stree(splitter="duck")
    # Only the dataset load and the fit run inside the assertion context;
    # the estimator is constructed beforehand.
    with self.assertRaises(ValueError):
        bogus_clf.fit(*load_dataset())