#6 - Add multiclass support

Removed predict_proba (for now). Created a Jupyter notebook.
Added split_criteria parameter with min_distance and max_samples values
Refactored _distances
Refactored _split_criteria
Refactored _reorder_results
2020-06-11 13:10:52 +02:00
parent 45510b43bc
commit f360a2640c
3 changed files with 156 additions and 199 deletions
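A quick usage sketch of the new split_criteria parameter. Only the parameter name and its two accepted values (min_distance, max_samples) come from this commit; the dataset and the remaining settings are illustrative:

from sklearn.datasets import load_iris
from stree import Stree

X, y = load_iris(return_X_y=True)
for criteria in ("min_distance", "max_samples"):
    # split_criteria selects how multiclass distances are reduced to one
    # signed distance per sample before splitting the node
    clf = Stree(split_criteria=criteria, random_state=0)
    clf.fit(X, y)
    print(f"{criteria}: accuracy={clf.score(X, y):.4f}")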

main.py

@@ -75,14 +75,3 @@ print(f"Took {time.time() - now:.2f} seconds to train")
 print(clf)
 print(f"Classifier's accuracy (train): {clf.score(Xtrain, ytrain):.4f}")
 print(f"Classifier's accuracy (test) : {clf.score(Xtest, ytest):.4f}")
-proba = clf.predict_proba(Xtest)
-print(
-    "Checking that we have correct probabilities, these are probabilities of "
-    "sample belonging to class 1"
-)
-res0 = proba[proba[:, 0] == 0]
-res1 = proba[proba[:, 0] == 1]
-print("++++++++++res0 > .8++++++++++++")
-print(res0[res0[:, 1] > 0.8])
-print("**********res1 < .4************")
-print(res1[res1[:, 1] < 0.4])
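With predict_proba removed for now, the demo can check hard predictions instead; a minimal sketch, assuming clf, Xtest and ytest exist as in the script above:

import numpy as np

yp = clf.predict(Xtest)  # hard class labels instead of probabilities
print(f"Manual test accuracy: {np.mean(yp == ytest):.4f}")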


@@ -131,6 +131,7 @@ class Stree(BaseEstimator, ClassifierMixin):
         tol: float = 1e-4,
         degree: int = 3,
         gamma="scale",
+        split_criteria="max_samples",
         min_samples_split: int = 0,
     ):
         self.max_iter = max_iter
@@ -142,17 +143,18 @@ class Stree(BaseEstimator, ClassifierMixin):
         self.gamma = gamma
         self.degree = degree
         self.min_samples_split = min_samples_split
+        self.split_criteria = split_criteria

     def _more_tags(self) -> dict:
-        """Required by sklearn to tell that this estimator is a binary classifier
+        """Required by sklearn to supply features of the classifier

         :return: the tag required
         :rtype: dict
         """
-        return {"binary_only": True, "requires_y": True}
+        return {"requires_y": True}

     def _split_array(self, origin: np.array, down: np.array) -> list:
-        """Split an array in two based on indices passed as down and its complement
+        """Split an array in two based on indices (down) and its complement

         :param origin: dataset to split
         :type origin: np.array
@@ -163,8 +165,8 @@ class Stree(BaseEstimator, ClassifierMixin):
         """
         up = ~down
         return (
-            origin[up[:, 0]] if any(up) else None,
-            origin[down[:, 0]] if any(down) else None,
+            origin[up] if any(up) else None,
+            origin[down] if any(down) else None,
         )

     def _distances(self, node: Snode, data: np.ndarray) -> np.array:
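A standalone sketch of the boolean-mask split that _split_array now performs: down selects one side of the hyperplane, its complement the other (values are illustrative):

import numpy as np

origin = np.arange(10).reshape(5, 2)  # five 2-feature samples
down = np.array([True, False, True, False, False])
up = ~down
# each side is None when no sample falls on it
up_side = origin[up] if any(up) else None
down_side = origin[down] if any(down) else None
print(up_side)    # rows 1, 3 and 4
print(down_side)  # rows 0 and 2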
@@ -178,27 +180,38 @@ class Stree(BaseEstimator, ClassifierMixin):
         the hyperplane of the node
         :rtype: np.array
         """
-        res = node._clf.decision_function(data)
-        if res.ndim == 1:
-            return np.expand_dims(res, 1)
-        elif res.shape[1] > 1:
-            # remove multiclass info
-            res = np.delete(res, slice(1, res.shape[1]), axis=1)
-        return res
+        return node._clf.decision_function(data)

-    def _split_criteria(self, data: np.array) -> np.array:
+    def _min_distance(self, data: np.array, _) -> np.array:
+        # chooses the lowest distance of every sample
+        indices = np.argmin(np.abs(data), axis=1)
+        return np.take(data, indices)
+
+    def _max_samples(self, data: np.array, y: np.array) -> np.array:
+        # select the class with max number of samples
+        _, samples = np.unique(y, return_counts=True)
+        selected = np.argmax(samples)
+        return data[:, selected]
+
+    def _split_criteria(self, data: np.array, node: Snode) -> np.array:
         """Set the criteria to split arrays

-        :param data: [description]
+        :param data: distances of samples to hyperplanes shape (m, nclasses)
+            if nclasses > 2 else (m,)
         :type data: np.array
-        :return: [description]
+        :param node: node containing the svm classifier
+        :type node: Snode
+        :return: array of booleans of samples under or above zero
         :rtype: np.array
         """
-        return (
-            data > 0
-            if data.shape[0] >= self.min_samples_split
-            else np.ones((data.shape[0], 1), dtype=bool)
-        )
+        if data.shape[0] < self.min_samples_split:
+            return np.ones((data.shape[0]), dtype=bool)
+        if data.ndim > 1:
+            # split criteria for multiclass
+            data = getattr(self, f"_{self.split_criteria}")(data, node._y)
+        res = data > 0
+        return res

     def fit(
         self, X: np.ndarray, y: np.ndarray, sample_weight: np.array = None
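A self-contained numpy sketch of what the two criteria compute when a node sees more than two classes, with data holding one signed distance per one-vs-rest hyperplane; the values are illustrative and the per-row selection is written out explicitly:

import numpy as np

# distances of 3 samples to 3 hyperplanes (shape (m, nclasses))
data = np.array([[0.3, -1.2, 2.0],
                 [-0.1, 0.4, -0.7],
                 [1.5, -0.2, 0.9]])
y = np.array([0, 1, 1])  # labels of the samples at this node

# min_distance: per sample, keep the distance of smallest magnitude
idx = np.argmin(np.abs(data), axis=1)
print(data[np.arange(data.shape[0]), idx])  # [ 0.3 -0.1 -0.2]

# max_samples: keep the whole column of the most populated class
_, counts = np.unique(y, return_counts=True)
print(data[:, np.argmax(counts)])  # column 1: [-1.2  0.4 -0.2]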
@@ -231,12 +244,19 @@ class Stree(BaseEstimator, ClassifierMixin):
                 f"Maximum depth has to be greater than 1... got (max_depth=\
                     {self.max_depth})"
             )
+        if self.split_criteria not in ["min_distance", "max_samples"]:
+            raise ValueError(
+                f"split_criteria has to be min_distance or \
+                    max_samples got ({self.split_criteria})"
+            )
         check_classification_targets(y)
         X, y = check_X_y(X, y)
         sample_weight = _check_sample_weight(sample_weight, X)
         check_classification_targets(y)
         # Initialize computed parameters
         self.classes_, y = np.unique(y, return_inverse=True)
+        self.n_classes_ = self.classes_.shape[0]
         self.n_iter_ = self.max_iter
         self.depth_ = 0
         self.n_features_in_ = X.shape[1]
@@ -244,6 +264,52 @@ class Stree(BaseEstimator, ClassifierMixin):
         self._build_predictor()
         return self

+    def train(
+        self,
+        X: np.ndarray,
+        y: np.ndarray,
+        sample_weight: np.ndarray,
+        depth: int,
+        title: str,
+    ) -> Snode:
+        """Recursive function to split the original dataset into predictor
+        nodes (leaves)
+
+        :param X: samples dataset
+        :type X: np.ndarray
+        :param y: samples labels
+        :type y: np.ndarray
+        :param sample_weight: weight of samples. Rescale C per sample.
+        Hi weights force the classifier to put more emphasis on these points.
+        :type sample_weight: np.ndarray
+        :param depth: actual depth in the tree
+        :type depth: int
+        :param title: description of the node
+        :type title: str
+        :return: binary tree
+        :rtype: Snode
+        """
+        if depth > self.__max_depth:
+            return None
+        if np.unique(y).shape[0] == 1:
+            # only 1 class => pure dataset
+            return Snode(None, X, y, title + ", <pure>")
+        # Train the model
+        clf = self._build_clf()
+        clf.fit(X, y, sample_weight=sample_weight)
+        node = Snode(clf, X, y, title)
+        self.depth_ = max(depth, self.depth_)
+        down = self._split_criteria(self._distances(node, X), node)
+        X_U, X_D = self._split_array(X, down)
+        y_u, y_d = self._split_array(y, down)
+        sw_u, sw_d = self._split_array(sample_weight, down)
+        if X_U is None or X_D is None:
+            # didn't part anything
+            return Snode(clf, X, y, title + ", <cgaf>")
+        node.set_up(self.train(X_U, y_u, sw_u, depth + 1, title + " - Up"))
+        node.set_down(self.train(X_D, y_d, sw_d, depth + 1, title + " - Down"))
+        return node
+
     def _build_predictor(self):
         """Process the leaves to make them predictors
         """
@@ -278,52 +344,6 @@ class Stree(BaseEstimator, ClassifierMixin):
             )
         )

-    def train(
-        self,
-        X: np.ndarray,
-        y: np.ndarray,
-        sample_weight: np.ndarray,
-        depth: int,
-        title: str,
-    ) -> Snode:
-        """Recursive function to split the original dataset into predictor
-        nodes (leaves)
-
-        :param X: samples dataset
-        :type X: np.ndarray
-        :param y: samples labels
-        :type y: np.ndarray
-        :param sample_weight: weight of samples. Rescale C per sample.
-        Hi weights force the classifier to put more emphasis on these points.
-        :type sample_weight: np.ndarray
-        :param depth: actual depth in the tree
-        :type depth: int
-        :param title: description of the node
-        :type title: str
-        :return: binary tree
-        :rtype: Snode
-        """
-        if depth > self.__max_depth:
-            return None
-        if np.unique(y).shape[0] == 1:
-            # only 1 class => pure dataset
-            return Snode(None, X, y, title + ", <pure>")
-        # Train the model
-        clf = self._build_clf()
-        clf.fit(X, y, sample_weight=sample_weight)
-        tree = Snode(clf, X, y, title)
-        self.depth_ = max(depth, self.depth_)
-        down = self._split_criteria(self._distances(tree, X))
-        X_U, X_D = self._split_array(X, down)
-        y_u, y_d = self._split_array(y, down)
-        sw_u, sw_d = self._split_array(sample_weight, down)
-        if X_U is None or X_D is None:
-            # didn't part anything
-            return Snode(clf, X, y, title + ", <cgaf>")
-        tree.set_up(self.train(X_U, y_u, sw_u, depth + 1, title + " - Up"))
-        tree.set_down(self.train(X_D, y_d, sw_d, depth + 1, title + " - Down"))
-        return tree
-
     def _reorder_results(self, y: np.array, indices: np.array) -> np.array:
         """Reorder an array based on the array of indices passed
@@ -334,10 +354,6 @@ class Stree(BaseEstimator, ClassifierMixin):
         :return: array y ordered
         :rtype: np.array
         """
-        if y.ndim > 1 and y.shape[1] > 1:
-            # if predict_proba return np.array of floats
-            y_ordered = np.zeros(y.shape, dtype=float)
-        else:
         # return array of same type given in y
         y_ordered = y.copy()
         indices = indices.astype(int)
@@ -363,11 +379,11 @@ class Stree(BaseEstimator, ClassifierMixin):
                 # set a class for every sample in dataset
                 prediction = np.full((xp.shape[0], 1), node._class)
                 return prediction, indices
-            down = self._split_criteria(self._distances(node, xp))
-            X_U, X_D = self._split_array(xp, down)
+            down = self._split_criteria(self._distances(node, xp), node)
+            x_u, x_d = self._split_array(xp, down)
             i_u, i_d = self._split_array(indices, down)
-            prx_u, prin_u = predict_class(X_U, i_u, node.get_up())
-            prx_d, prin_d = predict_class(X_D, i_d, node.get_down())
+            prx_u, prin_u = predict_class(x_u, i_u, node.get_up())
+            prx_d, prin_d = predict_class(x_d, i_d, node.get_down())
             return np.append(prx_u, prx_d), np.append(prin_u, prin_d)

         # sklearn check
@@ -383,68 +399,6 @@ class Stree(BaseEstimator, ClassifierMixin):
         )
         return self.classes_[result]

-    def predict_proba(self, X: np.array) -> np.array:
-        """Computes an approximation of the probability of samples belonging to
-        class 0 and 1
-
-        :param X: dataset
-        :type X: np.array
-        :return: array array of shape (m, num_classes), probability of being
-        each class
-        :rtype: np.array
-        """
-
-        def predict_class(
-            xp: np.array, indices: np.array, dist: np.array, node: Snode
-        ) -> np.array:
-            """Run the tree to compute predictions
-
-            :param xp: subdataset of samples
-            :type xp: np.array
-            :param indices: indices of subdataset samples to rebuild original
-            order
-            :type indices: np.array
-            :param dist: distances of every sample to the hyperplane or the
-            father node
-            :type dist: np.array
-            :param node: node of the leaf with the class
-            :type node: Snode
-            :return: array of labels and distances, array of indices
-            :rtype: np.array
-            """
-            if xp is None:
-                return [], []
-            if node.is_leaf():
-                # set a class for every sample in dataset
-                prediction = np.full((xp.shape[0], 1), node._class)
-                prediction_proba = dist
-                return np.append(prediction, prediction_proba, axis=1), indices
-            distances = self._distances(node, xp)
-            down = self._split_criteria(distances)
-            X_U, X_D = self._split_array(xp, down)
-            i_u, i_d = self._split_array(indices, down)
-            di_u, di_d = self._split_array(distances, down)
-            prx_u, prin_u = predict_class(X_U, i_u, di_u, node.get_up())
-            prx_d, prin_d = predict_class(X_D, i_d, di_d, node.get_down())
-            return np.append(prx_u, prx_d), np.append(prin_u, prin_d)
-
-        # sklearn check
-        check_is_fitted(self, ["tree_"])
-        # Input validation
-        X = check_array(X)
-        # setup prediction & make it happen
-        indices = np.arange(X.shape[0])
-        empty_dist = np.empty((X.shape[0], 1), dtype=float)
-        result, indices = predict_class(X, indices, empty_dist, self.tree_)
-        result = result.reshape(X.shape[0], 2)
-        # Turn distances to hyperplane into probabilities based on fitting
-        # distances of samples to its hyperplane that classified them, to the
-        # sigmoid function
-        # Probability of being 1
-        result[:, 1] = 1 / (1 + np.exp(-result[:, 1]))
-        # Probability of being 0
-        result[:, 0] = 1 - result[:, 1]
-        return self._reorder_results(result, indices)
-
     def score(
         self, X: np.array, y: np.array, sample_weight: np.array = None
     ) -> float:
@@ -473,12 +427,11 @@ class Stree(BaseEstimator, ClassifierMixin):
             score = differing_labels == 0
         else:
             score = y_true == y_pred
         return _weighted_sum(score, sample_weight, normalize=True)

     def __iter__(self) -> Siterator:
-        """Create an iterator to be able to visit the nodes of the tree in preorder,
-        can make a list with all the nodes in preorder
+        """Create an iterator to be able to visit the nodes of the tree in
+        preorder, can make a list with all the nodes in preorder

         :return: an iterator, can for i in... and list(...)
         :rtype: Siterator
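For reference, the removed predict_proba mapped each leaf distance to a pseudo-probability with a logistic sigmoid; a standalone sketch of that mapping (the function name is illustrative):

import numpy as np

def distance_to_proba(dist: np.ndarray) -> np.ndarray:
    # probability of class 1 via the sigmoid, class 0 as its complement
    p1 = 1.0 / (1.0 + np.exp(-dist))
    return np.column_stack((1.0 - p1, p1))

print(distance_to_proba(np.array([-2.0, 0.0, 2.0])))
# [[0.88079708 0.11920292]
#  [0.5        0.5       ]
#  [0.11920292 0.88079708]]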


@@ -2,23 +2,22 @@ import os
 import unittest

 import numpy as np
-from sklearn.datasets import make_classification
+from sklearn.datasets import make_classification, load_iris

 from stree import Stree, Snode


-def get_dataset(random_state=0):
+def get_dataset(random_state=0, n_classes=2):
     X, y = make_classification(
         n_samples=1500,
         n_features=3,
         n_informative=3,
         n_redundant=0,
         n_repeated=0,
-        n_classes=2,
+        n_classes=n_classes,
         n_clusters_per_class=2,
         class_sep=1.5,
         flip_y=0,
-        weights=[0.5, 0.5],
         random_state=random_state,
     )
     return X, y
@@ -104,9 +103,8 @@ class Stree_test(unittest.TestCase):
         return res

     def test_single_prediction(self):
-        probs = [0.29026400766, 0.73105613, 0.0307635]
         X, y = get_dataset(self._random_state)
-        for kernel, prob in zip(self._kernels, probs):
+        for kernel in self._kernels:
             clf = Stree(kernel=kernel, random_state=self._random_state)
             yp = clf.fit(X, y).predict((X[0, :].reshape(-1, X.shape[1])))
             self.assertEqual(yp[0], y[0])
@@ -122,10 +120,12 @@ class Stree_test(unittest.TestCase):
     def test_score(self):
         X, y = get_dataset(self._random_state)
-        for kernel, accuracy_expected in zip(
-            self._kernels,
-            [0.9506666666666667, 0.9606666666666667, 0.9433333333333334],
-        ):
+        accuracies = [
+            0.9506666666666667,
+            0.9606666666666667,
+            0.9433333333333334,
+        ]
+        for kernel, accuracy_expected in zip(self._kernels, accuracies):
             clf = Stree(random_state=self._random_state, kernel=kernel,)
             clf.fit(X, y)
             accuracy_score = clf.score(X, y)
@@ -134,38 +134,6 @@ class Stree_test(unittest.TestCase):
             self.assertEqual(accuracy_score, accuracy_computed)
             self.assertAlmostEqual(accuracy_expected, accuracy_score)

-    def test_single_predict_proba(self):
-        """Check the element 28 probability of being 1
-        """
-        decimals = 5
-        element = 28
-        probs = [0.29026400766, 0.73105613, 0.0307635]
-        X, y = get_dataset(self._random_state)
-        self.assertEqual(1, y[element])
-        for kernel, prob in zip(self._kernels, probs):
-            clf = Stree(kernel=kernel, random_state=self._random_state)
-            yp = clf.fit(X, y).predict_proba(
-                X[element, :].reshape(-1, X.shape[1])
-            )
-            self.assertAlmostEqual(
-                np.round(1 - prob, decimals), np.round(yp[0:, 0], decimals)
-            )
-            self.assertAlmostEqual(
-                round(prob, decimals), round(yp[0, 1], decimals), decimals
-            )
-
-    def test_multiple_predict_proba(self):
-        # First 27 elements the predictions are the same as the truth
-        num = 27
-        X, y = get_dataset(self._random_state)
-        for kernel in self._kernels:
-            clf = Stree(kernel=kernel, random_state=self._random_state)
-            clf.fit(X, y)
-            yp = clf.predict_proba(X[:num, :])
-            self.assertListEqual(
-                y[:num].tolist(), np.argmax(yp[:num], axis=1).tolist()
-            )
-
     def test_single_vs_multiple_prediction(self):
         """Check if predicting sample by sample gives the same result as
         predicting all samples at once
@@ -225,6 +193,11 @@ class Stree_test(unittest.TestCase):
         with self.assertRaises(ValueError):
             tclf.fit(*get_dataset(self._random_state))

+    def test_exception_if_bogus_split_criteria(self):
+        tclf = Stree(split_criteria="duck")
+        with self.assertRaises(ValueError):
+            tclf.fit(*get_dataset(self._random_state))
+
     def test_check_max_depth_is_positive_or_None(self):
         tcl = Stree()
         self.assertIsNone(tcl.max_depth)
@@ -256,14 +229,56 @@ class Stree_test(unittest.TestCase):
         self.assertIsNone(tcl_nosplit.tree_.get_down())
         self.assertIsNone(tcl_nosplit.tree_.get_up())

-    def test_muticlass_dataset(self):
+    def test_simple_muticlass_dataset(self):
         for kernel in self._kernels:
-            clf = Stree(kernel=kernel, random_state=self._random_state)
-            px = [[1, 2], [3, 4], [5, 6]]
-            py = [1, 2, 3]
+            clf = Stree(
+                kernel=kernel,
+                split_criteria="max_samples",
+                random_state=self._random_state,
+            )
+            px = [[1, 2], [5, 6], [9, 10]]
+            py = [0, 1, 2]
             clf.fit(px, py)
             self.assertEqual(1.0, clf.score(px, py))
-            self.assertListEqual([1, 2, 3], clf.predict(px).tolist())
+            self.assertListEqual(py, clf.predict(px).tolist())
+            self.assertListEqual(py, clf.classes_.tolist())
+
+    def test_muticlass_dataset(self):
+        datasets = {
+            "Synt": get_dataset(random_state=self._random_state, n_classes=3),
+            "Iris": load_iris(return_X_y=True),
+        }
+        outcomes = {
+            "Synt": {
+                "max_samples linear": 0.9533333333333334,
+                "max_samples rbf": 0.836,
+                "max_samples poly": 0.9473333333333334,
+                "min_distance linear": 0.9533333333333334,
+                "min_distance rbf": 0.836,
+                "min_distance poly": 0.9473333333333334,
+            },
+            "Iris": {
+                "max_samples linear": 0.98,
+                "max_samples rbf": 1.0,
+                "max_samples poly": 1.0,
+                "min_distance linear": 0.98,
+                "min_distance rbf": 1.0,
+                "min_distance poly": 1.0,
+            },
+        }
+        for name, dataset in datasets.items():
+            px, py = dataset
+            for criteria in ["max_samples", "min_distance"]:
+                for kernel in self._kernels:
+                    clf = Stree(
+                        C=1e4,
+                        max_iter=1e4,
+                        kernel=kernel,
+                        random_state=self._random_state,
+                    )
+                    clf.fit(px, py)
+                    outcome = outcomes[name][f"{criteria} {kernel}"]
+                    self.assertAlmostEqual(outcome, clf.score(px, py))


 class Snode_test(unittest.TestCase):
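A condensed, standalone version of the new multiclass check on Iris. The kernel list and random_state are assumptions (the test fixture defines self._kernels and self._random_state), and unlike the looped test above this sketch passes split_criteria explicitly:

from sklearn.datasets import load_iris
from stree import Stree

X, y = load_iris(return_X_y=True)
for criteria in ("max_samples", "min_distance"):
    for kernel in ("linear", "rbf", "poly"):
        clf = Stree(C=1e4, max_iter=10000, kernel=kernel,
                    split_criteria=criteria, random_state=0)
        clf.fit(X, y)
        print(f"{criteria:>12} {kernel:>6} score={clf.score(X, y):.2f}")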