#3 refactor unneeded code and new test

This commit is contained in:
2020-06-09 13:01:01 +02:00
parent 5c31c2b2a5
commit 286a91a3d7
2 changed files with 22 additions and 8 deletions

View File

@@ -68,14 +68,14 @@ class Snode:
if len(classes) > 1:
max_card = max(card)
min_card = min(card)
try:
self._belief = max_card / (max_card + min_card)
except ZeroDivisionError:
self._belief = 0.0
self._class = classes[card == max_card][0]
self._belief = max_card / (max_card + min_card)
else:
self._belief = 1
self._class = classes[0]
try:
self._class = classes[0]
except IndexError:
self._class = None
def __str__(self) -> str:
if self.is_leaf():
@@ -182,6 +182,7 @@ class Stree(BaseEstimator, ClassifierMixin):
if res.ndim == 1:
return np.expand_dims(res, 1)
elif res.shape[1] > 1:
# remove multiclass info
res = np.delete(res, slice(1, res.shape[1]), axis=1)
return res
@@ -216,8 +217,6 @@ class Stree(BaseEstimator, ClassifierMixin):
:rtype: Stree
"""
# Check parameters are Ok.
if type(y).__name__ == "np.ndarray":
y = y.ravel()
if self.C < 0:
raise ValueError(
f"Penalty term must be positive... got (C={self.C:f})"
@@ -463,7 +462,8 @@ class Stree(BaseEstimator, ClassifierMixin):
"""
# sklearn check
check_is_fitted(self)
check_classification_targets(y)
X, y = check_X_y(X, y)
y_pred = self.predict(X).reshape(y.shape)
# Compute accuracy for each possible representation
y_type, y_true, y_pred = _check_targets(y, y_pred)