Fix some tests

This commit is contained in:
2021-06-01 12:50:35 +02:00
parent 794374fe8c
commit b15a059b1d
3 changed files with 31 additions and 27 deletions

View File

@@ -47,12 +47,12 @@ class Metrics:
return d[:, -1] # returns the distance to the kth nearest neighbor
@staticmethod
def differential_entropy(X, k=1):
def differential_entropy(x, k=1):
"""Returns the entropy of the X.
Parameters
===========
X : array-like, shape (n_samples, n_features)
x : array-like, shape (n_samples, n_features)
The data the entropy of which is computed
k : int, optional
number of nearest neighbors for density estimation
@@ -66,11 +66,11 @@ class Metrics:
Kraskov A, Stogbauer H, Grassberger P. (2004). Estimating mutual
information. Phys Rev E 69(6 Pt 2):066138.
"""
if X.ndim == 1:
X = X.reshape(-1, 1)
if x.ndim == 1:
x = x.reshape(-1, 1)
# Distance to kth nearest neighbor
r = Metrics._nearest_distances(X, k) # squared distances
n, d = X.shape
r = Metrics._nearest_distances(x, k) # squared distances
n, d = x.shape
volume_unit_ball = (np.pi ** (0.5 * d)) / gamma(0.5 * d + 1)
"""
F. Perez-Cruz, (2008). Estimation of Information Theoretic Measures
@@ -79,7 +79,7 @@ class Metrics:
return d*mean(log(r))+log(volume_unit_ball)+log(n-1)-log(k)
"""
return (
d * np.mean(np.log(r + np.finfo(X.dtype).eps))
d * np.mean(np.log(r + np.finfo(x.dtype).eps))
+ np.log(volume_unit_ball)
+ psi(n)
- psi(k)