#6 - Add multiclass support

Removed predict_proba (for now). Created a Jupyter notebook.
Added split_criteria parameter with min_distance and max_samples values.
Refactored _distances.
Refactored _split_criteria.
Refactored _reorder_results.
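A minimal usage sketch of the new parameter (a hypothetical call, not part of the commit; Stree, split_criteria and its two values come from the diff below, load_iris from the tests):

    from sklearn.datasets import load_iris
    from stree import Stree

    X, y = load_iris(return_X_y=True)
    # "max_samples": split on the hyperplane of the most populated class
    # "min_distance": split on the hyperplane closest to each sample
    clf = Stree(kernel="linear", split_criteria="min_distance", random_state=0)
    clf.fit(X, y)
    print(f"accuracy: {clf.score(X, y):.4f}")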
2020-06-11 13:10:52 +02:00
parent 45510b43bc
commit f360a2640c
3 changed files with 156 additions and 199 deletions

main.py

@@ -75,14 +75,3 @@ print(f"Took {time.time() - now:.2f} seconds to train")
print(clf)
print(f"Classifier's accuracy (train): {clf.score(Xtrain, ytrain):.4f}")
print(f"Classifier's accuracy (test) : {clf.score(Xtest, ytest):.4f}")
proba = clf.predict_proba(Xtest)
print(
"Checking that we have correct probabilities, these are probabilities of "
"sample belonging to class 1"
)
res0 = proba[proba[:, 0] == 0]
res1 = proba[proba[:, 0] == 1]
print("++++++++++res0 > .8++++++++++++")
print(res0[res0[:, 1] > 0.8])
print("**********res1 < .4************")
print(res1[res1[:, 1] < 0.4])


@@ -131,6 +131,7 @@ class Stree(BaseEstimator, ClassifierMixin):
tol: float = 1e-4,
degree: int = 3,
gamma="scale",
split_criteria="max_samples",
min_samples_split: int = 0,
):
self.max_iter = max_iter
@@ -142,17 +143,18 @@ class Stree(BaseEstimator, ClassifierMixin):
self.gamma = gamma
self.degree = degree
self.min_samples_split = min_samples_split
self.split_criteria = split_criteria
def _more_tags(self) -> dict:
"""Required by sklearn to tell that this estimator is a binary classifier
"""Required by sklearn to supply features of the classifier
:return: the tag required
:rtype: dict
"""
return {"binary_only": True, "requires_y": True}
return {"requires_y": True}
def _split_array(self, origin: np.array, down: np.array) -> list:
"""Split an array in two based on indices passed as down and its complement
"""Split an array in two based on indices (down) and its complement
:param origin: dataset to split
:type origin: np.array
@@ -163,8 +165,8 @@ class Stree(BaseEstimator, ClassifierMixin):
"""
up = ~down
return (
origin[up[:, 0]] if any(up) else None,
origin[down[:, 0]] if any(down) else None,
origin[up] if any(up) else None,
origin[down] if any(down) else None,
)
def _distances(self, node: Snode, data: np.ndarray) -> np.array:
@@ -178,27 +180,38 @@ class Stree(BaseEstimator, ClassifierMixin):
the hyperplane of the node
:rtype: np.array
"""
res = node._clf.decision_function(data)
if res.ndim == 1:
return np.expand_dims(res, 1)
elif res.shape[1] > 1:
# remove multiclass info
res = np.delete(res, slice(1, res.shape[1]), axis=1)
return res
return node._clf.decision_function(data)
def _split_criteria(self, data: np.array) -> np.array:
def _min_distance(self, data: np.array, _) -> np.array:
        # choose, for every sample, the distance to its closest hyperplane
        indices = np.argmin(np.abs(data), axis=1)
        return data[np.arange(data.shape[0]), indices]
def _max_samples(self, data: np.array, y: np.array) -> np.array:
# select the class with max number of samples
_, samples = np.unique(y, return_counts=True)
selected = np.argmax(samples)
return data[:, selected]
def _split_criteria(self, data: np.array, node: Snode) -> np.array:
"""Set the criteria to split arrays
:param data: [description]
        :param data: distances of samples to hyperplanes, shape (m, nclasses)
if nclasses > 2 else (m,)
:type data: np.array
:return: [description]
:param node: node containing the svm classifier
:type node: Snode
        :return: boolean array, True for samples with distance above zero
:rtype: np.array
"""
return (
data > 0
if data.shape[0] >= self.min_samples_split
else np.ones((data.shape[0], 1), dtype=bool)
)
if data.shape[0] < self.min_samples_split:
return np.ones((data.shape[0]), dtype=bool)
if data.ndim > 1:
# split criteria for multiclass
data = getattr(self, f"_{self.split_criteria}")(data, node._y)
        return data > 0
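As an illustration of the two criteria, a standalone numpy sketch with toy values (hypothetical data, not part of the diff; it mirrors the _min_distance and _max_samples methods above and assumes labels 0..n-1 as produced in fit):

    import numpy as np

    # distances of 3 samples to 3 one-vs-rest hyperplanes (toy values)
    data = np.array([[-0.3, 1.2, 0.5],
                     [0.8, -0.1, -2.0],
                     [0.4, 0.2, -0.7]])
    y = np.array([0, 1, 1, 2, 2, 2])  # labels at the node; class 2 is largest

    # min_distance: keep, per sample, the signed distance smallest in magnitude
    idx = np.argmin(np.abs(data), axis=1)            # [0, 1, 1]
    min_dist = data[np.arange(data.shape[0]), idx]   # [-0.3, -0.1, 0.2]

    # max_samples: keep the column of the most populated class (class 2)
    _, counts = np.unique(y, return_counts=True)
    max_samp = data[:, np.argmax(counts)]            # [0.5, -2.0, -0.7]

    # either way, the boolean split mask is "distance > 0"
    print(min_dist > 0)   # [False False  True]
    print(max_samp > 0)   # [ True False False]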
def fit(
self, X: np.ndarray, y: np.ndarray, sample_weight: np.array = None
@@ -231,12 +244,19 @@ class Stree(BaseEstimator, ClassifierMixin):
f"Maximum depth has to be greater than 1... got (max_depth=\
{self.max_depth})"
)
        if self.split_criteria not in ["min_distance", "max_samples"]:
            raise ValueError(
                "split_criteria has to be min_distance or max_samples, "
                f"got ({self.split_criteria})"
            )
check_classification_targets(y)
X, y = check_X_y(X, y)
sample_weight = _check_sample_weight(sample_weight, X)
check_classification_targets(y)
# Initialize computed parameters
self.classes_, y = np.unique(y, return_inverse=True)
self.n_classes_ = self.classes_.shape[0]
self.n_iter_ = self.max_iter
self.depth_ = 0
self.n_features_in_ = X.shape[1]
@@ -244,6 +264,52 @@ class Stree(BaseEstimator, ClassifierMixin):
self._build_predictor()
return self
def train(
self,
X: np.ndarray,
y: np.ndarray,
sample_weight: np.ndarray,
depth: int,
title: str,
) -> Snode:
"""Recursive function to split the original dataset into predictor
nodes (leaves)
:param X: samples dataset
:type X: np.ndarray
:param y: samples labels
:type y: np.ndarray
        :param sample_weight: weight of samples. Rescales C per sample.
        High weights force the classifier to put more emphasis on these points.
:type sample_weight: np.ndarray
        :param depth: current depth in the tree
:type depth: int
:param title: description of the node
:type title: str
:return: binary tree
:rtype: Snode
"""
if depth > self.__max_depth:
return None
if np.unique(y).shape[0] == 1:
# only 1 class => pure dataset
return Snode(None, X, y, title + ", <pure>")
# Train the model
clf = self._build_clf()
clf.fit(X, y, sample_weight=sample_weight)
node = Snode(clf, X, y, title)
self.depth_ = max(depth, self.depth_)
down = self._split_criteria(self._distances(node, X), node)
X_U, X_D = self._split_array(X, down)
y_u, y_d = self._split_array(y, down)
sw_u, sw_d = self._split_array(sample_weight, down)
if X_U is None or X_D is None:
            # didn't split anything
return Snode(clf, X, y, title + ", <cgaf>")
node.set_up(self.train(X_U, y_u, sw_u, depth + 1, title + " - Up"))
node.set_down(self.train(X_D, y_d, sw_d, depth + 1, title + " - Down"))
return node
def _build_predictor(self):
"""Process the leaves to make them predictors
"""
@@ -278,52 +344,6 @@ class Stree(BaseEstimator, ClassifierMixin):
)
)
def train(
self,
X: np.ndarray,
y: np.ndarray,
sample_weight: np.ndarray,
depth: int,
title: str,
) -> Snode:
"""Recursive function to split the original dataset into predictor
nodes (leaves)
:param X: samples dataset
:type X: np.ndarray
:param y: samples labels
:type y: np.ndarray
:param sample_weight: weight of samples. Rescale C per sample.
Hi weights force the classifier to put more emphasis on these points.
:type sample_weight: np.ndarray
:param depth: actual depth in the tree
:type depth: int
:param title: description of the node
:type title: str
:return: binary tree
:rtype: Snode
"""
if depth > self.__max_depth:
return None
if np.unique(y).shape[0] == 1:
# only 1 class => pure dataset
return Snode(None, X, y, title + ", <pure>")
# Train the model
clf = self._build_clf()
clf.fit(X, y, sample_weight=sample_weight)
tree = Snode(clf, X, y, title)
self.depth_ = max(depth, self.depth_)
down = self._split_criteria(self._distances(tree, X))
X_U, X_D = self._split_array(X, down)
y_u, y_d = self._split_array(y, down)
sw_u, sw_d = self._split_array(sample_weight, down)
if X_U is None or X_D is None:
# didn't part anything
return Snode(clf, X, y, title + ", <cgaf>")
tree.set_up(self.train(X_U, y_u, sw_u, depth + 1, title + " - Up"))
tree.set_down(self.train(X_D, y_d, sw_d, depth + 1, title + " - Down"))
return tree
def _reorder_results(self, y: np.array, indices: np.array) -> np.array:
"""Reorder an array based on the array of indices passed
@@ -334,10 +354,6 @@ class Stree(BaseEstimator, ClassifierMixin):
:return: array y ordered
:rtype: np.array
"""
if y.ndim > 1 and y.shape[1] > 1:
# if predict_proba return np.array of floats
y_ordered = np.zeros(y.shape, dtype=float)
else:
# return array of same type given in y
y_ordered = y.copy()
indices = indices.astype(int)
@@ -363,11 +379,11 @@ class Stree(BaseEstimator, ClassifierMixin):
# set a class for every sample in dataset
prediction = np.full((xp.shape[0], 1), node._class)
return prediction, indices
down = self._split_criteria(self._distances(node, xp))
X_U, X_D = self._split_array(xp, down)
down = self._split_criteria(self._distances(node, xp), node)
x_u, x_d = self._split_array(xp, down)
i_u, i_d = self._split_array(indices, down)
prx_u, prin_u = predict_class(X_U, i_u, node.get_up())
prx_d, prin_d = predict_class(X_D, i_d, node.get_down())
prx_u, prin_u = predict_class(x_u, i_u, node.get_up())
prx_d, prin_d = predict_class(x_d, i_d, node.get_down())
return np.append(prx_u, prx_d), np.append(prin_u, prin_d)
# sklearn check
@@ -383,68 +399,6 @@ class Stree(BaseEstimator, ClassifierMixin):
)
return self.classes_[result]
def predict_proba(self, X: np.array) -> np.array:
"""Computes an approximation of the probability of samples belonging to
class 0 and 1
:param X: dataset
:type X: np.array
:return: array array of shape (m, num_classes), probability of being
each class
:rtype: np.array
"""
def predict_class(
xp: np.array, indices: np.array, dist: np.array, node: Snode
) -> np.array:
"""Run the tree to compute predictions
:param xp: subdataset of samples
:type xp: np.array
:param indices: indices of subdataset samples to rebuild original
order
:type indices: np.array
:param dist: distances of every sample to the hyperplane or the
father node
:type dist: np.array
:param node: node of the leaf with the class
:type node: Snode
:return: array of labels and distances, array of indices
:rtype: np.array
"""
if xp is None:
return [], []
if node.is_leaf():
# set a class for every sample in dataset
prediction = np.full((xp.shape[0], 1), node._class)
prediction_proba = dist
return np.append(prediction, prediction_proba, axis=1), indices
distances = self._distances(node, xp)
down = self._split_criteria(distances)
X_U, X_D = self._split_array(xp, down)
i_u, i_d = self._split_array(indices, down)
di_u, di_d = self._split_array(distances, down)
prx_u, prin_u = predict_class(X_U, i_u, di_u, node.get_up())
prx_d, prin_d = predict_class(X_D, i_d, di_d, node.get_down())
return np.append(prx_u, prx_d), np.append(prin_u, prin_d)
# sklearn check
check_is_fitted(self, ["tree_"])
# Input validation
X = check_array(X)
# setup prediction & make it happen
indices = np.arange(X.shape[0])
empty_dist = np.empty((X.shape[0], 1), dtype=float)
result, indices = predict_class(X, indices, empty_dist, self.tree_)
result = result.reshape(X.shape[0], 2)
# Turn distances to hyperplane into probabilities based on fitting
# distances of samples to its hyperplane that classified them, to the
# sigmoid function
# Probability of being 1
result[:, 1] = 1 / (1 + np.exp(-result[:, 1]))
# Probability of being 0
result[:, 0] = 1 - result[:, 1]
return self._reorder_results(result, indices)
def score(
self, X: np.array, y: np.array, sample_weight: np.array = None
) -> float:
@@ -473,12 +427,11 @@ class Stree(BaseEstimator, ClassifierMixin):
score = differing_labels == 0
else:
score = y_true == y_pred
return _weighted_sum(score, sample_weight, normalize=True)
def __iter__(self) -> Siterator:
"""Create an iterator to be able to visit the nodes of the tree in preorder,
can make a list with all the nodes in preorder
"""Create an iterator to be able to visit the nodes of the tree in
preorder, can make a list with all the nodes in preorder
:return: an iterator, can for i in... and list(...)
:rtype: Siterator


@@ -2,23 +2,22 @@ import os
import unittest
import numpy as np
from sklearn.datasets import make_classification
from sklearn.datasets import make_classification, load_iris
from stree import Stree, Snode
def get_dataset(random_state=0):
def get_dataset(random_state=0, n_classes=2):
X, y = make_classification(
n_samples=1500,
n_features=3,
n_informative=3,
n_redundant=0,
n_repeated=0,
n_classes=2,
n_classes=n_classes,
n_clusters_per_class=2,
class_sep=1.5,
flip_y=0,
weights=[0.5, 0.5],
random_state=random_state,
)
return X, y
@@ -104,9 +103,8 @@ class Stree_test(unittest.TestCase):
return res
def test_single_prediction(self):
probs = [0.29026400766, 0.73105613, 0.0307635]
X, y = get_dataset(self._random_state)
for kernel, prob in zip(self._kernels, probs):
for kernel in self._kernels:
clf = Stree(kernel=kernel, random_state=self._random_state)
yp = clf.fit(X, y).predict((X[0, :].reshape(-1, X.shape[1])))
self.assertEqual(yp[0], y[0])
@@ -122,10 +120,12 @@ class Stree_test(unittest.TestCase):
def test_score(self):
X, y = get_dataset(self._random_state)
for kernel, accuracy_expected in zip(
self._kernels,
[0.9506666666666667, 0.9606666666666667, 0.9433333333333334],
):
accuracies = [
0.9506666666666667,
0.9606666666666667,
0.9433333333333334,
]
for kernel, accuracy_expected in zip(self._kernels, accuracies):
            clf = Stree(random_state=self._random_state, kernel=kernel)
clf.fit(X, y)
accuracy_score = clf.score(X, y)
@@ -134,38 +134,6 @@ class Stree_test(unittest.TestCase):
self.assertEqual(accuracy_score, accuracy_computed)
self.assertAlmostEqual(accuracy_expected, accuracy_score)
def test_single_predict_proba(self):
"""Check the element 28 probability of being 1
"""
decimals = 5
element = 28
probs = [0.29026400766, 0.73105613, 0.0307635]
X, y = get_dataset(self._random_state)
self.assertEqual(1, y[element])
for kernel, prob in zip(self._kernels, probs):
clf = Stree(kernel=kernel, random_state=self._random_state)
yp = clf.fit(X, y).predict_proba(
X[element, :].reshape(-1, X.shape[1])
)
self.assertAlmostEqual(
np.round(1 - prob, decimals), np.round(yp[0:, 0], decimals)
)
self.assertAlmostEqual(
round(prob, decimals), round(yp[0, 1], decimals), decimals
)
def test_multiple_predict_proba(self):
# First 27 elements the predictions are the same as the truth
num = 27
X, y = get_dataset(self._random_state)
for kernel in self._kernels:
clf = Stree(kernel=kernel, random_state=self._random_state)
clf.fit(X, y)
yp = clf.predict_proba(X[:num, :])
self.assertListEqual(
y[:num].tolist(), np.argmax(yp[:num], axis=1).tolist()
)
def test_single_vs_multiple_prediction(self):
"""Check if predicting sample by sample gives the same result as
predicting all samples at once
@@ -225,6 +193,11 @@ class Stree_test(unittest.TestCase):
with self.assertRaises(ValueError):
tclf.fit(*get_dataset(self._random_state))
def test_exception_if_bogus_split_criteria(self):
tclf = Stree(split_criteria="duck")
with self.assertRaises(ValueError):
tclf.fit(*get_dataset(self._random_state))
def test_check_max_depth_is_positive_or_None(self):
tcl = Stree()
self.assertIsNone(tcl.max_depth)
@@ -256,14 +229,56 @@ class Stree_test(unittest.TestCase):
self.assertIsNone(tcl_nosplit.tree_.get_down())
self.assertIsNone(tcl_nosplit.tree_.get_up())
def test_muticlass_dataset(self):
def test_simple_muticlass_dataset(self):
for kernel in self._kernels:
clf = Stree(kernel=kernel, random_state=self._random_state)
px = [[1, 2], [3, 4], [5, 6]]
py = [1, 2, 3]
clf = Stree(
kernel=kernel,
split_criteria="max_samples",
random_state=self._random_state,
)
px = [[1, 2], [5, 6], [9, 10]]
py = [0, 1, 2]
clf.fit(px, py)
self.assertEqual(1.0, clf.score(px, py))
self.assertListEqual([1, 2, 3], clf.predict(px).tolist())
self.assertListEqual(py, clf.predict(px).tolist())
self.assertListEqual(py, clf.classes_.tolist())
def test_muticlass_dataset(self):
datasets = {
"Synt": get_dataset(random_state=self._random_state, n_classes=3),
"Iris": load_iris(return_X_y=True),
}
outcomes = {
"Synt": {
"max_samples linear": 0.9533333333333334,
"max_samples rbf": 0.836,
"max_samples poly": 0.9473333333333334,
"min_distance linear": 0.9533333333333334,
"min_distance rbf": 0.836,
"min_distance poly": 0.9473333333333334,
},
"Iris": {
"max_samples linear": 0.98,
"max_samples rbf": 1.0,
"max_samples poly": 1.0,
"min_distance linear": 0.98,
"min_distance rbf": 1.0,
"min_distance poly": 1.0,
},
}
for name, dataset in datasets.items():
px, py = dataset
for criteria in ["max_samples", "min_distance"]:
for kernel in self._kernels:
                    clf = Stree(
                        C=1e4,
                        max_iter=1e4,
                        split_criteria=criteria,
                        kernel=kernel,
                        random_state=self._random_state,
                    )
clf.fit(px, py)
outcome = outcomes[name][f"{criteria} {kernel}"]
self.assertAlmostEqual(outcome, clf.score(px, py))
class Snode_test(unittest.TestCase):