mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-15 23:46:02 +00:00
#3 refactor unneeded code and new test
This commit is contained in:
@@ -68,14 +68,14 @@ class Snode:
|
||||
if len(classes) > 1:
|
||||
max_card = max(card)
|
||||
min_card = min(card)
|
||||
try:
|
||||
self._belief = max_card / (max_card + min_card)
|
||||
except ZeroDivisionError:
|
||||
self._belief = 0.0
|
||||
self._class = classes[card == max_card][0]
|
||||
self._belief = max_card / (max_card + min_card)
|
||||
else:
|
||||
self._belief = 1
|
||||
self._class = classes[0]
|
||||
try:
|
||||
self._class = classes[0]
|
||||
except IndexError:
|
||||
self._class = None
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.is_leaf():
|
||||
@@ -182,6 +182,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
if res.ndim == 1:
|
||||
return np.expand_dims(res, 1)
|
||||
elif res.shape[1] > 1:
|
||||
# remove multiclass info
|
||||
res = np.delete(res, slice(1, res.shape[1]), axis=1)
|
||||
return res
|
||||
|
||||
@@ -216,8 +217,6 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
:rtype: Stree
|
||||
"""
|
||||
# Check parameters are Ok.
|
||||
if type(y).__name__ == "np.ndarray":
|
||||
y = y.ravel()
|
||||
if self.C < 0:
|
||||
raise ValueError(
|
||||
f"Penalty term must be positive... got (C={self.C:f})"
|
||||
@@ -463,7 +462,8 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
"""
|
||||
# sklearn check
|
||||
check_is_fitted(self)
|
||||
|
||||
check_classification_targets(y)
|
||||
X, y = check_X_y(X, y)
|
||||
y_pred = self.predict(X).reshape(y.shape)
|
||||
# Compute accuracy for each possible representation
|
||||
y_type, y_true, y_pred = _check_targets(y, y_pred)
|
||||
|
@@ -256,6 +256,15 @@ class Stree_test(unittest.TestCase):
|
||||
self.assertIsNone(tcl_nosplit.tree_.get_down())
|
||||
self.assertIsNone(tcl_nosplit.tree_.get_up())
|
||||
|
||||
def test_muticlass_dataset(self):
|
||||
for kernel in self._kernels:
|
||||
clf = Stree(kernel=kernel, random_state=self._random_state)
|
||||
px = [[1, 2], [3, 4], [5, 6]]
|
||||
py = [1, 2, 3]
|
||||
clf.fit(px, py)
|
||||
self.assertEqual(1.0, clf.score(px, py))
|
||||
self.assertListEqual([1, 2, 3], clf.predict(px).tolist())
|
||||
|
||||
|
||||
class Snode_test(unittest.TestCase):
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -324,3 +333,8 @@ class Snode_test(unittest.TestCase):
|
||||
test.make_predictor()
|
||||
self.assertIsNone(test._class)
|
||||
self.assertEqual(0, test._belief)
|
||||
|
||||
def test_make_predictor_on_leaf_bogus_data(self):
|
||||
test = Snode(None, [1, 2, 3, 4], [], "test")
|
||||
test.make_predictor()
|
||||
self.assertIsNone(test._class)
|
||||
|
Reference in New Issue
Block a user