mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-15 15:36:00 +00:00
Change entropy function with scipy (#38)
This commit is contained in:
@@ -478,18 +478,6 @@ class Splitter:
|
||||
|
||||
@staticmethod
|
||||
def _entropy(y: np.array) -> float:
|
||||
"""Compute entropy of a labels set
|
||||
|
||||
Parameters
|
||||
----------
|
||||
y : np.array
|
||||
set of labels
|
||||
|
||||
Returns
|
||||
-------
|
||||
float
|
||||
entropy
|
||||
"""
|
||||
n_labels = len(y)
|
||||
if n_labels <= 1:
|
||||
return 0
|
||||
@@ -497,13 +485,10 @@ class Splitter:
|
||||
proportions = counts / n_labels
|
||||
n_classes = np.count_nonzero(proportions)
|
||||
if n_classes <= 1:
|
||||
return 0
|
||||
entropy = 0.0
|
||||
# Compute standard entropy.
|
||||
for prop in proportions:
|
||||
if prop != 0.0:
|
||||
entropy -= prop * log(prop, n_classes)
|
||||
return entropy
|
||||
return 0.0
|
||||
from scipy.stats import entropy
|
||||
|
||||
return entropy(y, base=n_classes)
|
||||
|
||||
def information_gain(
|
||||
self, labels: np.array, labels_up: np.array, labels_dn: np.array
|
||||
|
Reference in New Issue
Block a user