diff --git a/stree/Strees.py b/stree/Strees.py index 0b02e39..e5ce526 100644 --- a/stree/Strees.py +++ b/stree/Strees.py @@ -68,14 +68,14 @@ class Snode: if len(classes) > 1: max_card = max(card) min_card = min(card) - try: - self._belief = max_card / (max_card + min_card) - except ZeroDivisionError: - self._belief = 0.0 self._class = classes[card == max_card][0] + self._belief = max_card / (max_card + min_card) else: self._belief = 1 - self._class = classes[0] + try: + self._class = classes[0] + except IndexError: + self._class = None def __str__(self) -> str: if self.is_leaf(): @@ -182,6 +182,7 @@ class Stree(BaseEstimator, ClassifierMixin): if res.ndim == 1: return np.expand_dims(res, 1) elif res.shape[1] > 1: + # remove multiclass info res = np.delete(res, slice(1, res.shape[1]), axis=1) return res @@ -216,8 +217,6 @@ class Stree(BaseEstimator, ClassifierMixin): :rtype: Stree """ # Check parameters are Ok. - if type(y).__name__ == "np.ndarray": - y = y.ravel() if self.C < 0: raise ValueError( f"Penalty term must be positive... got (C={self.C:f})" @@ -463,7 +462,8 @@ class Stree(BaseEstimator, ClassifierMixin): """ # sklearn check check_is_fitted(self) - + check_classification_targets(y) + X, y = check_X_y(X, y) y_pred = self.predict(X).reshape(y.shape) # Compute accuracy for each possible representation y_type, y_true, y_pred = _check_targets(y, y_pred) diff --git a/stree/tests/Strees_test.py b/stree/tests/Strees_test.py index fe3cd1f..8bda05b 100644 --- a/stree/tests/Strees_test.py +++ b/stree/tests/Strees_test.py @@ -256,6 +256,15 @@ class Stree_test(unittest.TestCase): self.assertIsNone(tcl_nosplit.tree_.get_down()) self.assertIsNone(tcl_nosplit.tree_.get_up()) + def test_muticlass_dataset(self): + for kernel in self._kernels: + clf = Stree(kernel=kernel, random_state=self._random_state) + px = [[1, 2], [3, 4], [5, 6]] + py = [1, 2, 3] + clf.fit(px, py) + self.assertEqual(1.0, clf.score(px, py)) + self.assertListEqual([1, 2, 3], clf.predict(px).tolist()) + class Snode_test(unittest.TestCase): def __init__(self, *args, **kwargs): @@ -324,3 +333,8 @@ class Snode_test(unittest.TestCase): test.make_predictor() self.assertIsNone(test._class) self.assertEqual(0, test._belief) + + def test_make_predictor_on_leaf_bogus_data(self): + test = Snode(None, [1, 2, 3, 4], [], "test") + test.make_predictor() + self.assertIsNone(test._class)