diff --git a/stree/Strees.py b/stree/Strees.py index c88e581..f47775a 100644 --- a/stree/Strees.py +++ b/stree/Strees.py @@ -284,9 +284,8 @@ class Splitter: :type data: np.array (m, n_classes) :param y: vector of labels (classes) :type y: np.array (m,) - :return: vector with the class assigned to each sample values - (can be 0, 1, ...) -1 if none produces information gain - :rtype: np.array shape (m,) + :return: column of dataset to be taken into account to split dataset + :rtype: int """ max_gain = 0 selected = -1 @@ -307,8 +306,8 @@ class Splitter: :type data: np.array (m, n_classes) :param y: vector of labels (classes) :type y: np.array (m,) - :return: vector with distances to hyperplane (can be positive or neg.) - :rtype: np.array shape (m,) + :return: column of dataset to be taken into account to split dataset + :rtype: int """ # select the class with max number of samples _, samples = np.unique(y, return_counts=True)