Change entropy function with scipy (#38)

This commit is contained in:
2021-11-01 18:41:15 +01:00
parent e5d49132ec
commit 7a625eee09

View File

@@ -478,18 +478,6 @@ class Splitter:
@staticmethod
def _entropy(y: np.array) -> float:
"""Compute entropy of a labels set
Parameters
----------
y : np.array
set of labels
Returns
-------
float
entropy
"""
n_labels = len(y)
if n_labels <= 1:
return 0
@@ -497,13 +485,10 @@ class Splitter:
proportions = counts / n_labels
n_classes = np.count_nonzero(proportions)
if n_classes <= 1:
return 0
entropy = 0.0
# Compute standard entropy.
for prop in proportions:
if prop != 0.0:
entropy -= prop * log(prop, n_classes)
return entropy
return 0.0
from scipy.stats import entropy
return entropy(y, base=n_classes)
def information_gain(
self, labels: np.array, labels_up: np.array, labels_dn: np.array