#2 update benchmark notebook

This commit is contained in:
2020-06-15 10:33:51 +02:00
parent c94bc068bd
commit 736ab7ef20
2 changed files with 48 additions and 210 deletions

View File

@@ -193,8 +193,8 @@ class Splitter:
def information_gain(
self, labels_up: np.array, labels_dn: np.array
) -> float:
card_up = labels_up.shape[0]
card_dn = labels_dn.shape[0]
card_up = labels_up.shape[0] if labels_up is not None else 0
card_dn = labels_dn.shape[0] if labels_dn is not None else 0
samples = card_up + card_dn
up = card_up / samples * self.criterion_function(labels_up)
dn = card_dn / samples * self.criterion_function(labels_dn)