diff --git a/stree/Splitter.py b/stree/Splitter.py index 94acf8b..b60f6cd 100644 --- a/stree/Splitter.py +++ b/stree/Splitter.py @@ -135,7 +135,7 @@ class Snode: if not self.is_leaf(): return classes, card = np.unique(self._y, return_counts=True) - self._proba = np.zeros((num_classes,)) + self._proba = np.zeros((num_classes,), dtype=np.int64) for c, n in zip(classes, card): self._proba[c] = n try: diff --git a/stree/Strees.py b/stree/Strees.py index d45b596..1f93252 100644 --- a/stree/Strees.py +++ b/stree/Strees.py @@ -367,28 +367,66 @@ class Stree(BaseEstimator, ClassifierMixin): ) ) - @staticmethod - def _reorder_results(y: np.array, indices: np.array) -> np.array: - """Reorder an array based on the array of indices passed + def __predict_class(self, X: np.array) -> np.array: + def compute_prediction(xp, indices, node): + if xp is None: + return + if node.is_leaf(): + # set a class for indices + result[indices] = node._proba + return + self.splitter_.partition(xp, node, train=False) + x_u, x_d = self.splitter_.part(xp) + i_u, i_d = self.splitter_.part(indices) + compute_prediction(x_u, i_u, node.get_up()) + compute_prediction(x_d, i_d, node.get_down()) + + # setup prediction & make it happen + result = np.zeros((X.shape[0], self.n_classes_)) + indices = np.arange(X.shape[0]) + compute_prediction(X, indices, self.tree_) + return result + + def check_predict(self, X) -> np.array: + check_is_fitted(self, ["tree_"]) + # Input validation + X = check_array(X) + if X.shape[1] != self.n_features_: + raise ValueError( + f"Expected {self.n_features_} features but got " + f"({X.shape[1]})" + ) + return X + + def predict_proba(self, X: np.array) -> np.array: + """Predict class probabilities of the input samples X. + + The predicted class probability is the fraction of samples of the same + class in a leaf. Parameters ---------- - y : np.array - data untidy - indices : np.array - indices used to set order + X : dataset of samples. Returns ------- - np.array - array y ordered + proba : array of shape (n_samples, n_classes) + The class probabilities of the input samples. + + Raises + ------ + ValueError + if dataset with inconsistent number of features + NotFittedError + if model is not fitted """ - # return array of same type given in y - y_ordered = y.copy() - indices = indices.astype(int) - for i, index in enumerate(indices): - y_ordered[index] = y[i] - return y_ordered + + X = self.check_predict(X) + # return # of samples of each class in leaf node + values = self.__predict_class(X) + normalizer = values.sum(axis=1)[:, np.newaxis] + normalizer[normalizer == 0.0] = 1.0 + return values / normalizer def predict(self, X: np.array) -> np.array: """Predict labels for each sample in dataset passed @@ -410,40 +448,8 @@ class Stree(BaseEstimator, ClassifierMixin): NotFittedError if model is not fitted """ - - def predict_class( - xp: np.array, indices: np.array, node: Snode - ) -> np.array: - if xp is None: - return [], [] - if node.is_leaf(): - # set a class for every sample in dataset - prediction = np.full((xp.shape[0], 1), node._class) - return prediction, indices - self.splitter_.partition(xp, node, train=False) - x_u, x_d = self.splitter_.part(xp) - i_u, i_d = self.splitter_.part(indices) - prx_u, prin_u = predict_class(x_u, i_u, node.get_up()) - prx_d, prin_d = predict_class(x_d, i_d, node.get_down()) - return np.append(prx_u, prx_d), np.append(prin_u, prin_d) - - # sklearn check - check_is_fitted(self, ["tree_"]) - # Input validation - X = check_array(X) - if X.shape[1] != self.n_features_: - raise ValueError( - f"Expected {self.n_features_} features but got " - f"({X.shape[1]})" - ) - # setup prediction & make it happen - indices = np.arange(X.shape[0]) - result = ( - self._reorder_results(*predict_class(X, indices, self.tree_)) - .astype(int) - .ravel() - ) - return self.classes_[result] + X = self.check_predict(X) + return self.classes_[np.argmax(self.__predict_class(X), axis=1)] def nodes_leaves(self) -> tuple: """Compute the number of nodes and leaves in the built tree diff --git a/stree/tests/Stree_test.py b/stree/tests/Stree_test.py index cb9be17..3a4247a 100644 --- a/stree/tests/Stree_test.py +++ b/stree/tests/Stree_test.py @@ -695,7 +695,7 @@ class Stree_test(unittest.TestCase): ) expected_tail = ( ' [shape=box style=filled label="class=1 impurity=0.000 ' - 'counts=[0. 1. 0.]"];\n}\n' + 'counts=[0 1 0]"];\n}\n' ) self.assertEqual(clf.graph(), expected_head + "}\n") clf.fit(X, y) @@ -715,7 +715,7 @@ class Stree_test(unittest.TestCase): ) expected_tail = ( ' [shape=box style=filled label="class=1 impurity=0.000 ' - 'counts=[0. 1. 0.]"];\n}\n' + 'counts=[0 1 0]"];\n}\n' ) self.assertEqual(clf.graph("Sample title"), expected_head + "}\n") clf.fit(X, y)