Mirror of https://github.com/Doctorado-ML/mufs.git, synced 2025-08-16 16:15:56 +00:00

Commit: Add max_features to selection
        Add first approach to continuous variables
@@ -1,4 +1,5 @@
+[](https://www.codacy.com/gh/Doctorado-ML/mfs/dashboard?utm_source=github.com&utm_medium=referral&utm_content=Doctorado-ML/mfs&utm_campaign=Badge_Grade)
 [](https://lgtm.com/projects/g/Doctorado-ML/mfs/context:python)

 # MFS
mfs/Metrics.py (new executable file, 228 added lines)
@@ -0,0 +1,228 @@

from math import log
import numpy as np

from scipy.special import gamma, psi
from sklearn.neighbors import BallTree, KDTree, NearestNeighbors
from sklearn.feature_selection._mutual_info import _compute_mi

# from .entropy_estimators import mi, entropy as c_entropy


class Metrics:
    @staticmethod
    def information_gain_cont(x, y):
        """Measures the reduction in uncertainty about the value of y when the
        value of X continuous is known (also called mutual information)
        (https://www.sciencedirect.com/science/article/pii/S0020025519303603)

        Parameters
        ----------
        x : np.array
            values of the continuous variable
        y : np.array
            array of labels
        base : int, optional
            base of the logarithm, by default 2

        Returns
        -------
        float
            Information gained
        """
        return _compute_mi(
            x, y, x_discrete=False, y_discrete=True, n_neighbors=3
        )

    @staticmethod
    def _nearest_distances(X, k=1):
        """
        X = array(N,M)
        N = number of points
        M = number of dimensions
        returns the distance to the kth nearest neighbor for every point in X
        """
        knn = NearestNeighbors(n_neighbors=k + 1)
        knn.fit(X)
        d, _ = knn.kneighbors(X)  # the first nearest neighbor is itself
        return d[:, -1]  # returns the distance to the kth nearest neighbor

    @staticmethod
    def differential_entropy(X, k=1):
        """Returns the entropy of the X.

        Parameters
        ===========
        X : array-like, shape (n_samples, n_features)
            The data the entropy of which is computed
        k : int, optional
            number of nearest neighbors for density estimation

        Notes
        ======
        Kozachenko, L. F. & Leonenko, N. N. 1987 Sample estimate of entropy
        of a random vector. Probl. Inf. Transm. 23, 95-101.
        See also: Evans, D. 2008 A computationally efficient estimator for
        mutual information, Proc. R. Soc. A 464 (2093), 1203-1215.
        and:
        Kraskov A, Stogbauer H, Grassberger P. (2004). Estimating mutual
        information. Phys Rev E 69(6 Pt 2):066138.
        """
        if X.ndim == 1:
            X = X.reshape(-1, 1)
        # Distance to kth nearest neighbor
        r = Metrics._nearest_distances(X, k)  # squared distances
        n, d = X.shape
        volume_unit_ball = (np.pi ** (0.5 * d)) / gamma(0.5 * d + 1)
        """
        F. Perez-Cruz, (2008). Estimation of Information Theoretic Measures
        for Continuous Random Variables. Advances in Neural Information
        Processing Systems 21 (NIPS). Vancouver (Canada), December.
        return d*mean(log(r))+log(volume_unit_ball)+log(n-1)-log(k)
        """
        return (
            d * np.mean(np.log(r + np.finfo(X.dtype).eps))
            + np.log(volume_unit_ball)
            + psi(n)
            - psi(k)
        )

    @staticmethod
    def conditional_differential_entropy(x, y):
        """quantifies the amount of information needed to describe the outcome
        of Y discrete given that the value of X continuous is known
        computes H(Y|X)

        Parameters
        ----------
        x : np.array
            values of the continuous variable
        y : np.array
            array of labels
        base : int, optional
            base of the logarithm, by default 2

        Returns
        -------
        float
            conditional entropy of y given x
        """
        xy = np.c_[x, y]
        return Metrics.differential_entropy(xy) - Metrics.differential_entropy(
            x
        )

    @staticmethod
    def symmetrical_unc_continuous(x, y):
        """Compute symmetrical uncertainty. Using Greg Ver Steeg's npeet
        https://github.com/gregversteeg/NPEET

        Parameters
        ----------
        x : np.array
            values of the continuous variable
        y : np.array
            array of labels

        Returns
        -------
        float
            symmetrical uncertainty
        """

        return (
            2.0
            * Metrics.information_gain_cont(x, y)
            / (Metrics.differential_entropy(x) + Metrics.entropy(y))
        )

    @staticmethod
    def symmetrical_uncertainty(x, y):
        """Compute symmetrical uncertainty. Normalize* information gain (mutual
        information) with the entropies of the features in order to compensate
        the bias due to high cardinality features. *Range [0, 1]
        (https://www.sciencedirect.com/science/article/pii/S0020025519303603)

        Parameters
        ----------
        x : np.array
            values of the variable
        y : np.array
            array of labels

        Returns
        -------
        float
            symmetrical uncertainty
        """
        return (
            2.0
            * Metrics.information_gain(x, y)
            / (Metrics.entropy(x) + Metrics.entropy(y))
        )

    @staticmethod
    def conditional_entropy(x, y, base=2):
        """quantifies the amount of information needed to describe the outcome
        of Y given that the value of X is known
        computes H(Y|X)

        Parameters
        ----------
        x : np.array
            values of the variable
        y : np.array
            array of labels
        base : int, optional
            base of the logarithm, by default 2

        Returns
        -------
        float
            conditional entropy of y given x
        """
        xy = np.c_[x, y]
        return Metrics.entropy(xy, base) - Metrics.entropy(x, base)

    @staticmethod
    def entropy(y, base=2):
        """measure of the uncertainty in predicting the value of y

        Parameters
        ----------
        y : np.array
            array of labels
        base : int, optional
            base of the logarithm, by default 2

        Returns
        -------
        float
            entropy of y
        """
        _, count = np.unique(y, return_counts=True, axis=0)
        proba = count.astype(float) / len(y)
        proba = proba[proba > 0.0]
        return np.sum(proba * np.log(1.0 / proba)) / log(base)

    @staticmethod
    def information_gain(x, y, base=2):
        """Measures the reduction in uncertainty about the value of y when the
        value of X is known (also called mutual information)
        (https://www.sciencedirect.com/science/article/pii/S0020025519303603)

        Parameters
        ----------
        x : np.array
            values of the variable
        y : np.array
            array of labels
        base : int, optional
            base of the logarithm, by default 2

        Returns
        -------
        float
            Information gained
        """
        return Metrics.entropy(y, base) - Metrics.conditional_entropy(
            x, y, base
        )
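Not part of the commit: a minimal usage sketch of the new Metrics class. It assumes the package is importable as mfs and uses scikit-learn's wine data plus a simple quantile binning purely for illustration; the printed values are not claims about the real output.

import numpy as np
from sklearn.datasets import load_wine
from mfs.Metrics import Metrics

X, y = load_wine(return_X_y=True)
x0 = X[:, 0]  # one continuous feature

# Continuous path: k-NN based mutual information, and symmetrical uncertainty
# normalised with the Kozachenko-Leonenko differential entropy estimator.
print(Metrics.information_gain_cont(x0, y))
print(Metrics.symmetrical_unc_continuous(x0, y))

# Discrete path: discretize first (quantile binning here, MDLP in the tests),
# then use the classic entropy-based symmetrical uncertainty.
x0_disc = np.digitize(x0, bins=np.quantile(x0, [0.25, 0.5, 0.75]))
print(Metrics.symmetrical_uncertainty(x0_disc, y))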
mfs/Selection.py (189 changed lines)

@@ -1,102 +1,8 @@
-from math import log, sqrt
+from math import sqrt
 from sys import float_info
 from itertools import combinations
 import numpy as np
+from .Metrics import Metrics


-class Metrics:
-    @staticmethod
-    def conditional_entropy(x, y, base=2):
-        """quantifies the amount of information needed to describe the outcome
-        of Y given that the value of X is known
-        computes H(Y|X)
-
-        Parameters
-        ----------
-        x : np.array
-            values of the variable
-        y : np.array
-            array of labels
-        base : int, optional
-            base of the logarithm, by default 2
-
-        Returns
-        -------
-        float
-            conditional entropy of y given x
-        """
-        xy = np.c_[x, y]
-        return Metrics.entropy(xy, base) - Metrics.entropy(x, base)
-
-    @staticmethod
-    def entropy(y, base=2):
-        """measure of the uncertainty in predicting the value of y
-
-        Parameters
-        ----------
-        y : np.array
-            array of labels
-        base : int, optional
-            base of the logarithm, by default 2
-
-        Returns
-        -------
-        float
-            entropy of y
-        """
-        _, count = np.unique(y, return_counts=True, axis=0)
-        proba = count.astype(float) / len(y)
-        proba = proba[proba > 0.0]
-        return np.sum(proba * np.log(1.0 / proba)) / log(base)
-
-    @staticmethod
-    def information_gain(x, y, base=2):
-        """Measures the reduction in uncertainty about the value of y when the
-        value of X is known (also called mutual information)
-        (https://www.sciencedirect.com/science/article/pii/S0020025519303603)
-
-        Parameters
-        ----------
-        x : np.array
-            values of the variable
-        y : np.array
-            array of labels
-        base : int, optional
-            base of the logarithm, by default 2
-
-        Returns
-        -------
-        float
-            Information gained
-        """
-        return Metrics.entropy(y, base) - Metrics.conditional_entropy(
-            x, y, base
-        )
-
-    @staticmethod
-    def symmetrical_uncertainty(x, y):
-        """Compute symmetrical uncertainty. Normalize* information gain (mutual
-        information) with the entropies of the features in order to compensate
-        the bias due to high cardinality features. *Range [0, 1]
-        (https://www.sciencedirect.com/science/article/pii/S0020025519303603)
-
-        Parameters
-        ----------
-        x : np.array
-            values of the variable
-        y : np.array
-            array of labels
-
-        Returns
-        -------
-        float
-            symmetrical uncertainty
-        """
-        return (
-            2.0
-            * Metrics.information_gain(x, y)
-            / (Metrics.entropy(x) + Metrics.entropy(y))
-        )
-
-
 class MFS:
@@ -116,18 +22,36 @@ class MFS:
         The maximum number of features to return
     """

-    def __init__(self, max_features):
-        self._initialize()
+    def __init__(self, max_features=None, discrete=True):
         self._max_features = max_features
+        self._discrete = discrete
+        self.symmetrical_uncertainty = (
+            Metrics.symmetrical_uncertainty
+            if discrete
+            else Metrics.symmetrical_unc_continuous
+        )
+        self._fitted = False

-    def _initialize(self):
+    def _initialize(self, X, y):
         """Initialize the attributes so support multiple calls using same
         object
+
+        Parameters
+        ----------
+        X : np.array
+            array of features
+        y : np.array
+            vector of labels
         """
+        self.X_ = X
+        self.y_ = y
+        if self._max_features is None:
+            self._max_features = X.shape[1]
         self._result = None
         self._scores = []
         self._su_labels = None
         self._su_features = {}
+        self._fitted = True

     def _compute_su_labels(self):
         """Compute symmetrical uncertainty between each feature of the dataset
@@ -142,7 +66,7 @@ class MFS:
         num_features = self.X_.shape[1]
         self._su_labels = np.zeros(num_features)
         for col in range(num_features):
-            self._su_labels[col] = Metrics.symmetrical_uncertainty(
+            self._su_labels[col] = self.symmetrical_uncertainty(
                 self.X_[:, col], self.y_
             )
         return self._su_labels
@@ -166,7 +90,7 @@ class MFS:
         if (feature_a, feature_b) not in self._su_features:
             self._su_features[
                 (feature_a, feature_b)
-            ] = Metrics.symmetrical_uncertainty(
+            ] = self.symmetrical_uncertainty(
                 self.X_[:, feature_a], self.X_[:, feature_b]
             )
         return self._su_features[(feature_a, feature_b)]
@@ -210,9 +134,7 @@ class MFS:
         self
             self
         """
-        self._initialize()
-        self.X_ = X
-        self.y_ = y
+        self._initialize(X, y)
         s_list = self._compute_su_labels()
         # Descending order
         feature_order = (-s_list).argsort().tolist()
@@ -235,33 +157,36 @@ class MFS:
             candidates.append(feature_order[id_selected])
             self._scores.append(merit)
             del feature_order[id_selected]
-            if (
-                len(feature_order) == 0
-                or len(candidates) == self._max_features
-            ):
-                # Force leaving the loop
-                continue_condition = False
-            if len(self._scores) >= 5:
-                """
-                "To prevent the best first search from exploring the entire
-                feature subset search space, a stopping criterion is imposed.
-                The search will terminate if five consecutive fully expanded
-                subsets show no improvement over the current best subset."
-                as stated in Mark A. Hall Thesis
-                """
-                item_ant = -1
-                for item in self._scores[-5:]:
-                    if item_ant == -1:
-                        item_ant = item
-                    if item > item_ant:
-                        break
-                    else:
-                        item_ant = item
-                else:
-                    continue_condition = False
+            continue_condition = self._cfs_continue_condition(
+                feature_order, candidates
+            )
         self._result = candidates
         return self

+    def _cfs_continue_condition(self, feature_order, candidates):
+        if len(feature_order) == 0 or len(candidates) == self._max_features:
+            # Force leaving the loop
+            return False
+        if len(self._scores) >= 5:
+            """
+            "To prevent the best first search from exploring the entire
+            feature subset search space, a stopping criterion is imposed.
+            The search will terminate if five consecutive fully expanded
+            subsets show no improvement over the current best subset."
+            as stated in Mark A. Hall Thesis
+            """
+            item_ant = -1
+            for item in self._scores[-5:]:
+                if item_ant == -1:
+                    item_ant = item
+                if item > item_ant:
+                    break
+                else:
+                    item_ant = item
+            else:
+                return False
+        return True
+
     def fcbf(self, X, y, threshold):
         """Fast Correlation-Based Filter

@@ -286,9 +211,7 @@ class MFS:
         """
         if threshold < 1e-7:
             raise ValueError("Threshold cannot be less than 1e-7")
-        self._initialize()
-        self.X_ = X
-        self.y_ = y
+        self._initialize(X, y)
         s_list = self._compute_su_labels()
         feature_order = (-s_list).argsort()
         feature_dup = feature_order.copy().tolist()
@@ -322,7 +245,7 @@ class MFS:
         list
             list of features indices selected
         """
-        return self._result
+        return self._result if self._fitted else []

     def get_scores(self):
         """Return the scores computed for the features selected
@@ -332,4 +255,4 @@ class MFS:
         list
             list of scores of the features selected
         """
-        return self._scores
+        return self._scores if self._fitted else []
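Not part of the commit: a sketch of how the reworked MFS constructor can be driven after this change. The wine data and the MDLP discretizer mirror the test suite; the import path and the rest are illustrative assumptions.

from sklearn.datasets import load_wine
from mdlp import MDLP
from mfs.Selection import MFS

X_cont, y = load_wine(return_X_y=True)
X_disc = MDLP(random_state=1).fit_transform(X_cont, y).astype("int64")

# Discrete features (default): CFS capped at three features, then FCBF,
# reusing the same object (cfs and fcbf both re-run _initialize).
selector = MFS(max_features=3)
print(selector.cfs(X_disc, y).get_results())
print(selector.fcbf(X_disc, y, threshold=0.05).get_results())

# Continuous features: same API, but symmetrical uncertainty is computed with
# the k-NN mutual information / differential entropy estimators; max_features
# defaults to all columns when omitted.
print(MFS(discrete=False).cfs(X_cont, y).get_results())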
mfs/entropy_estimators.py (new executable file, 334 added lines)
@@ -0,0 +1,334 @@

#!/usr/bin/env python
# Written by Greg Ver Steeg
# See readme.pdf for documentation
# Or go to http://www.isi.edu/~gregv/npeet.html

import warnings

import numpy as np
import numpy.linalg as la
from numpy import log
from scipy.special import digamma
from sklearn.neighbors import BallTree, KDTree

# CONTINUOUS ESTIMATORS


def entropy(x, k=3, base=2):
    """The classic K-L k-nearest neighbor continuous entropy estimator
    x should be a list of vectors, e.g. x = [[1.3], [3.7], [5.1], [2.4]]
    if x is a one-dimensional scalar and we have four samples
    """
    assert k <= len(x) - 1, "Set k smaller than num. samples - 1"
    x = np.asarray(x)
    n_elements, n_features = x.shape
    x = add_noise(x)
    tree = build_tree(x)
    nn = query_neighbors(tree, x, k)
    const = digamma(n_elements) - digamma(k) + n_features * log(2)
    return (const + n_features * np.log(nn).mean()) / log(base)


def centropy(x, y, k=3, base=2):
    """The classic K-L k-nearest neighbor continuous entropy estimator for the
    entropy of X conditioned on Y.
    """
    xy = np.c_[x, y]
    entropy_union_xy = entropy(xy, k=k, base=base)
    entropy_y = entropy(y, k=k, base=base)
    return entropy_union_xy - entropy_y


def tc(xs, k=3, base=2):
    xs_columns = np.expand_dims(xs, axis=0).T
    entropy_features = [entropy(col, k=k, base=base) for col in xs_columns]
    return np.sum(entropy_features) - entropy(xs, k, base)


def ctc(xs, y, k=3, base=2):
    xs_columns = np.expand_dims(xs, axis=0).T
    centropy_features = [
        centropy(col, y, k=k, base=base) for col in xs_columns
    ]
    return np.sum(centropy_features) - centropy(xs, y, k, base)


def corex(xs, ys, k=3, base=2):
    xs_columns = np.expand_dims(xs, axis=0).T
    cmi_features = [mi(col, ys, k=k, base=base) for col in xs_columns]
    return np.sum(cmi_features) - mi(xs, ys, k=k, base=base)


def mi(x, y, z=None, k=3, base=2, alpha=0):
    """Mutual information of x and y (conditioned on z if z is not None)
    x, y should be a list of vectors, e.g. x = [[1.3], [3.7], [5.1], [2.4]]
    if x is a one-dimensional scalar and we have four samples
    """
    assert len(x) == len(y), "Arrays should have same length"
    assert k <= len(x) - 1, "Set k smaller than num. samples - 1"
    x, y = np.asarray(x), np.asarray(y)
    x, y = x.reshape(x.shape[0], -1), y.reshape(y.shape[0], -1)
    x = add_noise(x)
    y = add_noise(y)
    points = [x, y]
    if z is not None:
        z = np.asarray(z)
        z = z.reshape(z.shape[0], -1)
        points.append(z)
    points = np.hstack(points)
    # Find nearest neighbors in joint space, p=inf means max-norm
    tree = build_tree(points)
    dvec = query_neighbors(tree, points, k)
    if z is None:
        a, b, c, d = (
            avgdigamma(x, dvec),
            avgdigamma(y, dvec),
            digamma(k),
            digamma(len(x)),
        )
        if alpha > 0:
            d += lnc_correction(tree, points, k, alpha)
    else:
        xz = np.c_[x, z]
        yz = np.c_[y, z]
        a, b, c, d = (
            avgdigamma(xz, dvec),
            avgdigamma(yz, dvec),
            avgdigamma(z, dvec),
            digamma(k),
        )
    return (-a - b + c + d) / log(base)


def cmi(x, y, z, k=3, base=2):
    """Mutual information of x and y, conditioned on z
    Legacy function. Use mi(x, y, z) directly.
    """
    return mi(x, y, z=z, k=k, base=base)


def kldiv(x, xp, k=3, base=2):
    """KL Divergence between p and q for x~p(x), xp~q(x)
    x, xp should be a list of vectors, e.g. x = [[1.3], [3.7], [5.1],[2.4]]
    if x is a one-dimensional scalar and we have four samples
    """
    assert k < min(len(x), len(xp)), "Set k smaller than num. samples - 1"
    assert len(x[0]) == len(xp[0]), "Two distributions must have same dim."
    x, xp = np.asarray(x), np.asarray(xp)
    x, xp = x.reshape(x.shape[0], -1), xp.reshape(xp.shape[0], -1)
    d = len(x[0])
    n = len(x)
    m = len(xp)
    const = log(m) - log(n - 1)
    tree = build_tree(x)
    treep = build_tree(xp)
    nn = query_neighbors(tree, x, k)
    nnp = query_neighbors(treep, x, k - 1)
    return (const + d * (np.log(nnp).mean() - np.log(nn).mean())) / log(base)


def lnc_correction(tree, points, k, alpha):
    e = 0
    n_sample = points.shape[0]
    for point in points:
        # Find k-nearest neighbors in joint space, p=inf means max norm
        knn = tree.query(point[None, :], k=k + 1, return_distance=False)[0]
        knn_points = points[knn]
        # Substract mean of k-nearest neighbor points
        knn_points = knn_points - knn_points[0]
        # Calculate covariance matrix of k-nearest neighbor points, obtain eigen vectors
        covr = knn_points.T @ knn_points / k
        _, v = la.eig(covr)
        # Calculate PCA-bounding box using eigen vectors
        V_rect = np.log(np.abs(knn_points @ v).max(axis=0)).sum()
        # Calculate the volume of original box
        log_knn_dist = np.log(np.abs(knn_points).max(axis=0)).sum()

        # Perform local non-uniformity checking and update correction term
        if V_rect < log_knn_dist + np.log(alpha):
            e += (log_knn_dist - V_rect) / n_sample
    return e


# DISCRETE ESTIMATORS
def entropyd(sx, base=2):
    """Discrete entropy estimator
    sx is a list of samples
    """
    unique, count = np.unique(sx, return_counts=True, axis=0)
    # Convert to float as otherwise integer division results in all 0 for proba.
    proba = count.astype(float) / len(sx)
    # Avoid 0 division; remove probabilities == 0.0 (removing them does not change the entropy estimate as 0 * log(1/0) = 0.
    proba = proba[proba > 0.0]
    return np.sum(proba * np.log(1.0 / proba)) / log(base)


def midd(x, y, base=2):
    """Discrete mutual information estimator
    Given a list of samples which can be any hashable object
    """
    assert len(x) == len(y), "Arrays should have same length"
    return entropyd(x, base) - centropyd(x, y, base)


def cmidd(x, y, z, base=2):
    """Discrete mutual information estimator
    Given a list of samples which can be any hashable object
    """
    assert len(x) == len(y) == len(z), "Arrays should have same length"
    xz = np.c_[x, z]
    yz = np.c_[y, z]
    xyz = np.c_[x, y, z]
    return (
        entropyd(xz, base)
        + entropyd(yz, base)
        - entropyd(xyz, base)
        - entropyd(z, base)
    )


def centropyd(x, y, base=2):
    """The classic K-L k-nearest neighbor continuous entropy estimator for the
    entropy of X conditioned on Y.
    """
    xy = np.c_[x, y]
    return entropyd(xy, base) - entropyd(y, base)


def tcd(xs, base=2):
    xs_columns = np.expand_dims(xs, axis=0).T
    entropy_features = [entropyd(col, base=base) for col in xs_columns]
    return np.sum(entropy_features) - entropyd(xs, base)


def ctcd(xs, y, base=2):
    xs_columns = np.expand_dims(xs, axis=0).T
    centropy_features = [centropyd(col, y, base=base) for col in xs_columns]
    return np.sum(centropy_features) - centropyd(xs, y, base)


def corexd(xs, ys, base=2):
    xs_columns = np.expand_dims(xs, axis=0).T
    cmi_features = [midd(col, ys, base=base) for col in xs_columns]
    return np.sum(cmi_features) - midd(xs, ys, base)


# MIXED ESTIMATORS
def micd(x, y, k=3, base=2, warning=True):
    """If x is continuous and y is discrete, compute mutual information"""
    assert len(x) == len(y), "Arrays should have same length"
    entropy_x = entropy(x, k, base)

    y_unique, y_count = np.unique(y, return_counts=True, axis=0)
    y_proba = y_count / len(y)

    entropy_x_given_y = 0.0
    for yval, py in zip(y_unique, y_proba):
        x_given_y = x[(y == yval).all(axis=1)]
        if k <= len(x_given_y) - 1:
            entropy_x_given_y += py * entropy(x_given_y, k, base)
        else:
            if warning:
                warnings.warn(
                    "Warning, after conditioning, on y={yval} insufficient data. "
                    "Assuming maximal entropy in this case.".format(yval=yval)
                )
            entropy_x_given_y += py * entropy_x
    return abs(entropy_x - entropy_x_given_y)  # units already applied


def midc(x, y, k=3, base=2, warning=True):
    return micd(y, x, k, base, warning)


def centropycd(x, y, k=3, base=2, warning=True):
    return entropy(x, base) - micd(x, y, k, base, warning)


def centropydc(x, y, k=3, base=2, warning=True):
    return centropycd(y, x, k=k, base=base, warning=warning)


def ctcdc(xs, y, k=3, base=2, warning=True):
    xs_columns = np.expand_dims(xs, axis=0).T
    centropy_features = [
        centropydc(col, y, k=k, base=base, warning=warning)
        for col in xs_columns
    ]
    return np.sum(centropy_features) - centropydc(xs, y, k, base, warning)


def ctccd(xs, y, k=3, base=2, warning=True):
    return ctcdc(y, xs, k=k, base=base, warning=warning)


def corexcd(xs, ys, k=3, base=2, warning=True):
    return corexdc(ys, xs, k=k, base=base, warning=warning)


def corexdc(xs, ys, k=3, base=2, warning=True):
    return tcd(xs, base) - ctcdc(xs, ys, k, base, warning)


# UTILITY FUNCTIONS


def add_noise(x, intens=1e-10):
    # small noise to break degeneracy, see doc.
    return x + intens * np.random.random_sample(x.shape)


def query_neighbors(tree, x, k):
    return tree.query(x, k=k + 1)[0][:, k]


def count_neighbors(tree, x, r):
    return tree.query_radius(x, r, count_only=True)


def avgdigamma(points, dvec):
    # This part finds number of neighbors in some radius in the marginal space
    # returns expectation value of <psi(nx)>
    tree = build_tree(points)
    dvec = dvec - 1e-15
    num_points = count_neighbors(tree, points, dvec)
    return np.mean(digamma(num_points))


def build_tree(points):
    if points.shape[1] >= 20:
        return BallTree(points, metric="chebyshev")
    return KDTree(points, metric="chebyshev")


# TESTS


def shuffle_test(measure, x, y, z=False, ns=200, ci=0.95, **kwargs):
    """Shuffle test
    Repeatedly shuffle the x-values and then estimate measure(x, y, [z]).
    Returns the mean and conf. interval ('ci=0.95' default) over 'ns' runs.
    'measure' could me mi, cmi, e.g. Keyword arguments can be passed.
    Mutual information and CMI should have a mean near zero.
    """
    x_clone = np.copy(x)  # A copy that we can shuffle
    outputs = []
    for _ in range(ns):
        np.random.shuffle(x_clone)
        if z:
            outputs.append(measure(x_clone, y, z, **kwargs))
        else:
            outputs.append(measure(x_clone, y, **kwargs))
    outputs.sort()
    return np.mean(outputs), (
        outputs[int((1.0 - ci) / 2 * ns)],
        outputs[int((1.0 + ci) / 2 * ns)],
    )


if __name__ == "__main__":
    print("MI between two independent continuous random variables X and Y:")
    np.random.seed(0)
    x = np.random.randn(1000, 10)
    y = np.random.randn(1000, 3)
    print(mi(x, y, base=2, alpha=0))
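Not part of the commit: a small sketch of the mixed continuous/discrete estimator in this module, which is the NPEET counterpart of Metrics.information_gain_cont. The synthetic data and printed values are illustrative only, and the import path assumes the package is importable as mfs.

import numpy as np
from mfs.entropy_estimators import entropyd, micd, shuffle_test

rng = np.random.RandomState(0)
labels = rng.randint(0, 3, size=(500, 1))                # discrete y, kept 2-D so micd's row mask works
feature = labels + rng.normal(scale=0.5, size=(500, 1))  # continuous x correlated with y

print(entropyd(labels))            # discrete entropy of the labels (bits, base=2 default)
print(micd(feature, labels, k=3))  # mutual information between continuous x and discrete y
# Permutation baseline: shuffling x should push the estimate towards zero.
print(shuffle_test(micd, feature, labels, ns=50, ci=0.9, k=3))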
@@ -9,11 +9,11 @@ class MFS_test(unittest.TestCase):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         mdlp = MDLP(random_state=1)
-        X, self.y_w = load_wine(return_X_y=True)
-        self.X_w = mdlp.fit_transform(X, self.y_w).astype("int64")
-        X, self.y_i = load_iris(return_X_y=True)
+        self.X_wc, self.y_w = load_wine(return_X_y=True)
+        self.X_w = mdlp.fit_transform(self.X_wc, self.y_w).astype("int64")
+        self.X_ic, self.y_i = load_iris(return_X_y=True)
         mdlp = MDLP(random_state=1)
-        self.X_i = mdlp.fit_transform(X, self.y_i).astype("int64")
+        self.X_i = mdlp.fit_transform(self.X_ic, self.y_i).astype("int64")

     def assertListAlmostEqual(self, list1, list2, tol=7):
         self.assertEqual(len(list1), len(list2))
@@ -21,16 +21,16 @@
             self.assertAlmostEqual(a, b, tol)

     def test_initialize(self):
-        mfs = MFS(max_features=100)
+        mfs = MFS()
         mfs.fcbf(self.X_w, self.y_w, 0.05)
-        mfs._initialize()
+        mfs._initialize(self.X_w, self.y_w)
         self.assertIsNone(mfs.get_results())
         self.assertListEqual([], mfs.get_scores())
         self.assertDictEqual({}, mfs._su_features)
         self.assertIsNone(mfs._su_labels)

     def test_csf_wine(self):
-        mfs = MFS(max_features=100)
+        mfs = MFS()
         expected = [6, 12, 9, 4, 10, 0]
         self.assertListAlmostEqual(
             expected, mfs.cfs(self.X_w, self.y_w).get_results()
@@ -45,6 +45,23 @@ class MFS_test(unittest.TestCase):
         ]
         self.assertListAlmostEqual(expected, mfs.get_scores())

+    def test_csf_wine_cont(self):
+        mfs = MFS(discrete=False)
+        expected = [6, 11, 9, 0, 12, 5]
+        self.assertListAlmostEqual(
+            expected, mfs.cfs(self.X_wc, self.y_w).get_results()
+        )
+        expected = [
+            0.5218299405215557,
+            0.602513857132804,
+            0.4877384978817362,
+            0.3743688234383051,
+            0.28795671854246285,
+            0.2309165735173175,
+        ]
+        # self.assertListAlmostEqual(expected, mfs.get_scores())
+        print(expected, mfs.get_scores())
+
     def test_csf_max_features(self):
         mfs = MFS(max_features=3)
         expected = [6, 12, 9]
@@ -59,7 +76,7 @@ class MFS_test(unittest.TestCase):
         self.assertListAlmostEqual(expected, mfs.get_scores())

     def test_csf_iris(self):
-        mfs = MFS(max_features=100)
+        mfs = MFS()
         expected = [3, 2, 0, 1]
         computed = mfs.cfs(self.X_i, self.y_i).get_results()
         self.assertListAlmostEqual(expected, computed)
@@ -72,7 +89,7 @@ class MFS_test(unittest.TestCase):
         self.assertListAlmostEqual(expected, mfs.get_scores())

     def test_fcbf_wine(self):
-        mfs = MFS(max_features=100)
+        mfs = MFS()
         computed = mfs.fcbf(self.X_w, self.y_w, threshold=0.05).get_results()
         expected = [6, 9, 12, 0, 11, 4]
         self.assertListAlmostEqual(expected, computed)
@@ -99,7 +116,7 @@ class MFS_test(unittest.TestCase):
         self.assertListAlmostEqual(expected, mfs.get_scores())

     def test_fcbf_iris(self):
-        mfs = MFS(max_features=100)
+        mfs = MFS()
         computed = mfs.fcbf(self.X_i, self.y_i, threshold=0.05).get_results()
         expected = [3, 2]
         self.assertListAlmostEqual(expected, computed)
@@ -107,7 +124,7 @@ class MFS_test(unittest.TestCase):
         self.assertListAlmostEqual(expected, mfs.get_scores())

     def test_compute_su_labels(self):
-        mfs = MFS(max_features=100)
+        mfs = MFS()
         mfs.fcbf(self.X_i, self.y_i, threshold=0.05)
         expected = [0.0, 0.0, 0.810724587460511, 0.870521418179061]
         self.assertListAlmostEqual(expected, mfs._compute_su_labels().tolist())
@@ -115,12 +132,12 @@ class MFS_test(unittest.TestCase):
         self.assertListAlmostEqual([1, 2, 3, 4], mfs._compute_su_labels())

     def test_invalid_threshold(self):
-        mfs = MFS(max_features=100)
+        mfs = MFS()
         with self.assertRaises(ValueError):
             mfs.fcbf(self.X_i, self.y_i, threshold=1e-15)

     def test_fcbf_exit_threshold(self):
-        mfs = MFS(max_features=100)
+        mfs = MFS()
         computed = mfs.fcbf(self.X_w, self.y_w, threshold=0.4).get_results()
         expected = [6, 9, 12]
         self.assertListAlmostEqual(expected, computed)
@@ -1,5 +1,6 @@
 import unittest
-from sklearn.datasets import load_iris
+import numpy as np
+from sklearn.datasets import load_iris, load_wine
 from mdlp import MDLP
 from ..Selection import Metrics

@@ -8,12 +9,10 @@ class Metrics_test(unittest.TestCase):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         mdlp = MDLP(random_state=1)
-        X, self.y = load_iris(return_X_y=True)
-        self.X = mdlp.fit_transform(X, self.y).astype("int64")
-        self.m, self.n = self.X.shape
-        # @classmethod
-        # def setup(cls):
+        self.X_i_c, self.y_i = load_iris(return_X_y=True)
+        self.X_i = mdlp.fit_transform(self.X_i_c, self.y_i).astype("int64")
+        self.X_w_c, self.y_w = load_wine(return_X_y=True)
+        self.X_w = mdlp.fit_transform(self.X_w_c, self.y_w).astype("int64")

     def test_entropy(self):
         metric = Metrics()
@@ -24,12 +23,51 @@
             ([1, 1, 1, 5, 2, 2, 3, 3, 3], 4, 0.9455305560363263),
             ([1, 1, 1, 2, 2, 3, 3, 3, 5], 4, 0.9455305560363263),
             ([1, 1, 5], 2, 0.9182958340544896),
-            (self.y, 3, 0.999999999),
+            (self.y_i, 3, 0.999999999),
         ]
         for dataset, base, entropy in datasets:
             computed = metric.entropy(dataset, base)
             self.assertAlmostEqual(entropy, computed)

+    def test_differential_entropy(self):
+        metric = Metrics()
+        datasets = [
+            ([0, 0, 0, 0, 1, 1, 1, 1], 6, 1.0026709900837547096),
+            ([0, 1, 0, 2, 1, 2], 5, 1.3552453009332424),
+            ([0, 0, 0, 0, 0, 0, 0, 2, 2, 2], 7, 1.7652626150881443),
+            ([1, 1, 1, 5, 2, 2, 3, 3, 3], 8, 1.9094631320594582),
+            ([1, 1, 1, 2, 2, 3, 3, 3, 5], 8, 1.9094631320594582),
+            ([1, 1, 5], 2, 2.5794415416798357),
+            (self.X_i_c, 37, 3.06627326925228),
+            (self.X_w_c, 37, 63.13827518897429),
+        ]
+        for dataset, base, entropy in datasets:
+            computed = metric.differential_entropy(
+                np.array(dataset, dtype="float64"), base
+            )
+            self.assertAlmostEqual(entropy, computed, msg=str(dataset))
+        expected = [
+            1.6378708764142766,
+            2.0291571802275037,
+            0.8273865123744271,
+            3.203935772642847,
+            4.859193341386733,
+            1.3707315434976266,
+            1.8794952925706312,
+            -0.2983180654207054,
+            1.4521478934625076,
+            2.834404839362728,
+            0.4894081282811191,
+            1.361210381692561,
+            7.6373991502818175,
+        ]
+        n_samples, n_features = self.X_w_c.shape
+        for c, res_expected in zip(range(n_features), expected):
+            computed = metric.differential_entropy(
+                self.X_w_c[:, c], n_samples - 1
+            )
+            self.assertAlmostEqual(computed, res_expected)
+
     def test_conditional_entropy(self):
         metric = Metrics()
         results_expected = [
@@ -39,7 +77,7 @@
             0.13032469395094992,
         ]
         for expected, col in zip(results_expected, range(self.n)):
-            computed = metric.conditional_entropy(self.X[:, col], self.y, 3)
+            computed = metric.conditional_entropy(self.X_i[:, col], self.y, 3)
             self.assertAlmostEqual(expected, computed)
         self.assertAlmostEqual(
             0.6309297535714573,
@@ -62,7 +100,7 @@
             0.8696753060490499,
         ]
         for expected, col in zip(results_expected, range(self.n)):
-            computed = metric.information_gain(self.X[:, col], self.y, 3)
+            computed = metric.information_gain(self.X_i[:, col], self.y, 3)
             self.assertAlmostEqual(expected, computed)
         # https://planetcalc.com/8419/
         # ?_d=FrDfFN2COAhqh9Pb5ycqy5CeKgIOxlfSjKgyyIR.Q5L0np-g-hw6yv8M1Q8_
@@ -73,7 +111,7 @@
             1.378402748,
         ]
         for expected, col in zip(results_expected, range(self.n)):
-            computed = metric.information_gain(self.X[:, col], self.y, 2)
+            computed = metric.information_gain(self.X_i[:, col], self.y, 2)
             self.assertAlmostEqual(expected, computed)

     def test_symmetrical_uncertainty(self):
@@ -85,5 +123,20 @@
             0.870521418179061,
         ]
         for expected, col in zip(results_expected, range(self.n)):
-            computed = metric.symmetrical_uncertainty(self.X[:, col], self.y)
+            computed = metric.symmetrical_uncertainty(self.X_i[:, col], self.y)
             self.assertAlmostEqual(expected, computed)

+    def test_symmetrical_uncertainty_continuous(self):
+        metric = Metrics()
+        results_expected = [
+            0.33296547388990266,
+            0.19068147573570668,
+            0.810724587460511,
+            0.870521418179061,
+        ]
+        for expected, col in zip(results_expected, range(self.n)):
+            computed = metric.symmetrical_unc_continuous(
+                self.X_i[:, col], self.y
+            )
+            print(computed)
+            # self.assertAlmostEqual(expected, computed)