mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-17 16:36:01 +00:00
#3 refactor unneeded code and new test
This commit is contained in:
@@ -68,14 +68,14 @@ class Snode:
|
|||||||
if len(classes) > 1:
|
if len(classes) > 1:
|
||||||
max_card = max(card)
|
max_card = max(card)
|
||||||
min_card = min(card)
|
min_card = min(card)
|
||||||
try:
|
|
||||||
self._belief = max_card / (max_card + min_card)
|
|
||||||
except ZeroDivisionError:
|
|
||||||
self._belief = 0.0
|
|
||||||
self._class = classes[card == max_card][0]
|
self._class = classes[card == max_card][0]
|
||||||
|
self._belief = max_card / (max_card + min_card)
|
||||||
else:
|
else:
|
||||||
self._belief = 1
|
self._belief = 1
|
||||||
self._class = classes[0]
|
try:
|
||||||
|
self._class = classes[0]
|
||||||
|
except IndexError:
|
||||||
|
self._class = None
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
if self.is_leaf():
|
if self.is_leaf():
|
||||||
@@ -182,6 +182,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
if res.ndim == 1:
|
if res.ndim == 1:
|
||||||
return np.expand_dims(res, 1)
|
return np.expand_dims(res, 1)
|
||||||
elif res.shape[1] > 1:
|
elif res.shape[1] > 1:
|
||||||
|
# remove multiclass info
|
||||||
res = np.delete(res, slice(1, res.shape[1]), axis=1)
|
res = np.delete(res, slice(1, res.shape[1]), axis=1)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@@ -216,8 +217,6 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
:rtype: Stree
|
:rtype: Stree
|
||||||
"""
|
"""
|
||||||
# Check parameters are Ok.
|
# Check parameters are Ok.
|
||||||
if type(y).__name__ == "np.ndarray":
|
|
||||||
y = y.ravel()
|
|
||||||
if self.C < 0:
|
if self.C < 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Penalty term must be positive... got (C={self.C:f})"
|
f"Penalty term must be positive... got (C={self.C:f})"
|
||||||
@@ -463,7 +462,8 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
"""
|
"""
|
||||||
# sklearn check
|
# sklearn check
|
||||||
check_is_fitted(self)
|
check_is_fitted(self)
|
||||||
|
check_classification_targets(y)
|
||||||
|
X, y = check_X_y(X, y)
|
||||||
y_pred = self.predict(X).reshape(y.shape)
|
y_pred = self.predict(X).reshape(y.shape)
|
||||||
# Compute accuracy for each possible representation
|
# Compute accuracy for each possible representation
|
||||||
y_type, y_true, y_pred = _check_targets(y, y_pred)
|
y_type, y_true, y_pred = _check_targets(y, y_pred)
|
||||||
|
@@ -256,6 +256,15 @@ class Stree_test(unittest.TestCase):
|
|||||||
self.assertIsNone(tcl_nosplit.tree_.get_down())
|
self.assertIsNone(tcl_nosplit.tree_.get_down())
|
||||||
self.assertIsNone(tcl_nosplit.tree_.get_up())
|
self.assertIsNone(tcl_nosplit.tree_.get_up())
|
||||||
|
|
||||||
|
def test_muticlass_dataset(self):
|
||||||
|
for kernel in self._kernels:
|
||||||
|
clf = Stree(kernel=kernel, random_state=self._random_state)
|
||||||
|
px = [[1, 2], [3, 4], [5, 6]]
|
||||||
|
py = [1, 2, 3]
|
||||||
|
clf.fit(px, py)
|
||||||
|
self.assertEqual(1.0, clf.score(px, py))
|
||||||
|
self.assertListEqual([1, 2, 3], clf.predict(px).tolist())
|
||||||
|
|
||||||
|
|
||||||
class Snode_test(unittest.TestCase):
|
class Snode_test(unittest.TestCase):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
@@ -324,3 +333,8 @@ class Snode_test(unittest.TestCase):
|
|||||||
test.make_predictor()
|
test.make_predictor()
|
||||||
self.assertIsNone(test._class)
|
self.assertIsNone(test._class)
|
||||||
self.assertEqual(0, test._belief)
|
self.assertEqual(0, test._belief)
|
||||||
|
|
||||||
|
def test_make_predictor_on_leaf_bogus_data(self):
|
||||||
|
test = Snode(None, [1, 2, 3, 4], [], "test")
|
||||||
|
test.make_predictor()
|
||||||
|
self.assertIsNone(test._class)
|
||||||
|
Reference in New Issue
Block a user