#3 refactor unneeded code and new test

This commit is contained in:
2020-06-09 13:01:01 +02:00
parent 5c31c2b2a5
commit 286a91a3d7
2 changed files with 22 additions and 8 deletions

View File

@@ -68,14 +68,14 @@ class Snode:
if len(classes) > 1:
max_card = max(card)
min_card = min(card)
try:
self._belief = max_card / (max_card + min_card)
except ZeroDivisionError:
self._belief = 0.0
self._class = classes[card == max_card][0]
self._belief = max_card / (max_card + min_card)
else:
self._belief = 1
self._class = classes[0]
try:
self._class = classes[0]
except IndexError:
self._class = None
def __str__(self) -> str:
if self.is_leaf():
@@ -182,6 +182,7 @@ class Stree(BaseEstimator, ClassifierMixin):
if res.ndim == 1:
return np.expand_dims(res, 1)
elif res.shape[1] > 1:
# remove multiclass info
res = np.delete(res, slice(1, res.shape[1]), axis=1)
return res
@@ -216,8 +217,6 @@ class Stree(BaseEstimator, ClassifierMixin):
:rtype: Stree
"""
# Check parameters are Ok.
if type(y).__name__ == "np.ndarray":
y = y.ravel()
if self.C < 0:
raise ValueError(
f"Penalty term must be positive... got (C={self.C:f})"
@@ -463,7 +462,8 @@ class Stree(BaseEstimator, ClassifierMixin):
"""
# sklearn check
check_is_fitted(self)
check_classification_targets(y)
X, y = check_X_y(X, y)
y_pred = self.predict(X).reshape(y.shape)
# Compute accuracy for each possible representation
y_type, y_true, y_pred = _check_targets(y, y_pred)

View File

@@ -256,6 +256,15 @@ class Stree_test(unittest.TestCase):
self.assertIsNone(tcl_nosplit.tree_.get_down())
self.assertIsNone(tcl_nosplit.tree_.get_up())
def test_muticlass_dataset(self):
for kernel in self._kernels:
clf = Stree(kernel=kernel, random_state=self._random_state)
px = [[1, 2], [3, 4], [5, 6]]
py = [1, 2, 3]
clf.fit(px, py)
self.assertEqual(1.0, clf.score(px, py))
self.assertListEqual([1, 2, 3], clf.predict(px).tolist())
class Snode_test(unittest.TestCase):
def __init__(self, *args, **kwargs):
@@ -324,3 +333,8 @@ class Snode_test(unittest.TestCase):
test.make_predictor()
self.assertIsNone(test._class)
self.assertEqual(0, test._belief)
def test_make_predictor_on_leaf_bogus_data(self):
test = Snode(None, [1, 2, 3, 4], [], "test")
test.make_predictor()
self.assertIsNone(test._class)