From 7a625eee09dd75caf899b45e5448e482db990637 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?=
Date: Mon, 1 Nov 2021 18:41:15 +0100
Subject: [PATCH] Change entropy function with scipy (#38)

---
Review note: scipy.stats.entropy expects a probability distribution (or
unnormalised counts) and normalises its input, so it must receive
`proportions` rather than the raw label vector `y`. The index hashes below
are those of the original commit and do not reflect this correction.

 stree/Splitter.py | 23 ++++-------------------
 1 file changed, 4 insertions(+), 19 deletions(-)

diff --git a/stree/Splitter.py b/stree/Splitter.py
index acf737e..8e06ac3 100644
--- a/stree/Splitter.py
+++ b/stree/Splitter.py
@@ -478,18 +478,6 @@ class Splitter:
 
     @staticmethod
     def _entropy(y: np.array) -> float:
-        """Compute entropy of a labels set
-
-        Parameters
-        ----------
-        y : np.array
-            set of labels
-
-        Returns
-        -------
-        float
-            entropy
-        """
         n_labels = len(y)
         if n_labels <= 1:
             return 0
@@ -497,13 +485,10 @@ class Splitter:
         proportions = counts / n_labels
         n_classes = np.count_nonzero(proportions)
         if n_classes <= 1:
-            return 0
-        entropy = 0.0
-        # Compute standard entropy.
-        for prop in proportions:
-            if prop != 0.0:
-                entropy -= prop * log(prop, n_classes)
-        return entropy
+            return 0.0
+        from scipy.stats import entropy
+
+        return entropy(proportions, base=n_classes)
 
     def information_gain(
         self, labels: np.array, labels_up: np.array, labels_dn: np.array