# 2 - add max_features parameters

This commit is contained in:
2020-06-13 17:58:45 +02:00
parent 1bfe273a70
commit ae1c199e21
2 changed files with 104 additions and 8 deletions

View File

@@ -295,3 +295,47 @@ class Stree_test(unittest.TestCase):
computed = clf._max_samples(data, y)
self.assertEqual((4,), computed.shape)
self.assertListEqual(expected.tolist(), computed.tolist())
def test_max_features(self):
n_features = 16
expected_values = [
("auto", 4),
("log2", 4),
("sqrt", 4),
(0.5, 8),
(3, 3),
(None, 16),
]
clf = Stree()
clf.n_features_ = n_features
for max_features, expected in expected_values:
clf.set_params(**dict(max_features=max_features))
computed = clf._initialize_max_features()
self.assertEqual(expected, computed)
# Check bogus max_features
values = ["duck", -0.1, 0.0]
for max_features in values:
clf.set_params(**dict(max_features=max_features))
with self.assertRaises(ValueError):
_ = clf._initialize_max_features()
def test_get_subspaces(self):
dataset = np.random.random((10, 16))
y = np.random.randint(0, 2, 10)
expected_values = [
("auto", 4),
("log2", 4),
("sqrt", 4),
(0.5, 8),
(3, 3),
(None, 16),
]
clf = Stree()
for max_features, expected in expected_values:
clf.set_params(**dict(max_features=max_features))
clf.fit(dataset, y)
computed, indices = clf._get_subspace(dataset)
self.assertListEqual(
dataset[:, indices].tolist(), computed.tolist()
)
self.assertEqual(expected, len(indices))