mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-18 17:06:01 +00:00
#3 refactor unneeded code and new test
This commit is contained in:
@@ -68,14 +68,14 @@ class Snode:
|
||||
if len(classes) > 1:
|
||||
max_card = max(card)
|
||||
min_card = min(card)
|
||||
try:
|
||||
self._belief = max_card / (max_card + min_card)
|
||||
except ZeroDivisionError:
|
||||
self._belief = 0.0
|
||||
self._class = classes[card == max_card][0]
|
||||
self._belief = max_card / (max_card + min_card)
|
||||
else:
|
||||
self._belief = 1
|
||||
self._class = classes[0]
|
||||
try:
|
||||
self._class = classes[0]
|
||||
except IndexError:
|
||||
self._class = None
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.is_leaf():
|
||||
@@ -182,6 +182,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
if res.ndim == 1:
|
||||
return np.expand_dims(res, 1)
|
||||
elif res.shape[1] > 1:
|
||||
# remove multiclass info
|
||||
res = np.delete(res, slice(1, res.shape[1]), axis=1)
|
||||
return res
|
||||
|
||||
@@ -216,8 +217,6 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
:rtype: Stree
|
||||
"""
|
||||
# Check parameters are Ok.
|
||||
if type(y).__name__ == "np.ndarray":
|
||||
y = y.ravel()
|
||||
if self.C < 0:
|
||||
raise ValueError(
|
||||
f"Penalty term must be positive... got (C={self.C:f})"
|
||||
@@ -463,7 +462,8 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
"""
|
||||
# sklearn check
|
||||
check_is_fitted(self)
|
||||
|
||||
check_classification_targets(y)
|
||||
X, y = check_X_y(X, y)
|
||||
y_pred = self.predict(X).reshape(y.shape)
|
||||
# Compute accuracy for each possible representation
|
||||
y_type, y_true, y_pred = _check_targets(y, y_pred)
|
||||
|
Reference in New Issue
Block a user