#2 Refactor Stree & create Splitter

Add and test splitter parameter
This commit is contained in:
2020-06-15 00:22:57 +02:00
parent 502ee72799
commit c94bc068bd
7 changed files with 529 additions and 176 deletions

File diff suppressed because one or more lines are too long

View File

@@ -9,12 +9,14 @@ Build an oblique tree classifier based on SVM Trees
import os
import numbers
import random
import warnings
from itertools import combinations
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.svm import SVC, LinearSVC
from sklearn.utils import check_consistent_length
from sklearn.utils.multiclass import check_classification_targets
from sklearn.exceptions import ConvergenceWarning
from sklearn.utils.validation import (
check_X_y,
check_array,
@@ -134,6 +136,168 @@ class Siterator:
return node
class Splitter:
def __init__(
self,
clf: SVC = None,
criterion: str = None,
splitter_type: str = None,
criteria: str = None,
min_samples_split: int = None,
random_state=None,
):
self._clf = clf
self._random_state = random_state
if random_state is not None:
random.seed(random_state)
self._criterion = criterion
self._min_samples_split = min_samples_split
self._criteria = criteria
self._splitter_type = splitter_type
if clf is None:
raise ValueError(f"clf has to be a sklearn estimator, got({clf})")
if criterion not in ["gini", "entropy"]:
raise ValueError(
f"criterion must be gini or entropy got({criterion})"
)
if criteria not in ["min_distance", "max_samples"]:
raise ValueError(
f"split_criteria has to be min_distance or \
max_samples got ({criteria})"
)
if splitter_type not in ["random", "best"]:
raise ValueError(
f"splitter must be either random or best got({splitter_type})"
)
self.criterion_function = getattr(self, f"_{self._criterion}")
self.decision_criteria = getattr(self, f"_{self._criteria}")
def impurity(self, y: np.array) -> np.array:
return self.criterion_function(y)
@staticmethod
def _gini(y: np.array) -> float:
_, count = np.unique(y, return_counts=True)
return 1 - np.sum(np.square(count / np.sum(count)))
@staticmethod
def _entropy(y: np.array) -> float:
_, count = np.unique(y, return_counts=True)
proportion = count / np.sum(count)
return -np.sum(proportion * np.log2(proportion))
def information_gain(
self, labels_up: np.array, labels_dn: np.array
) -> float:
card_up = labels_up.shape[0]
card_dn = labels_dn.shape[0]
samples = card_up + card_dn
up = card_up / samples * self.criterion_function(labels_up)
dn = card_dn / samples * self.criterion_function(labels_dn)
return up + dn
def _select_best_set(
self, dataset: np.array, labels: np.array, features_sets: list
) -> list:
min_impurity = 1
selected = None
warnings.filterwarnings("ignore", category=ConvergenceWarning)
for feature_set in features_sets:
self._clf.fit(dataset[:, feature_set], labels)
node = Snode(
self._clf, dataset, labels, feature_set, 0.0, "subset"
)
self.partition(dataset, node)
y1, y2 = self.part(labels)
impurity = self.information_gain(y1, y2)
if impurity < min_impurity:
min_impurity = impurity
selected = feature_set
return selected
def _get_subspaces_set(
self, dataset: np.array, labels: np.array, max_features: int
) -> np.array:
features = range(dataset.shape[1])
features_sets = list(combinations(features, max_features))
if len(features_sets) > 1:
if self._splitter_type == "random":
return features_sets[random.randint(0, len(features_sets) - 1)]
else:
return self._select_best_set(dataset, labels, features_sets)
else:
return features_sets[0]
def get_subspace(
self, dataset: np.array, labels: np.array, max_features: int
) -> list:
"""Return the best subspace to make a split
"""
indices = self._get_subspaces_set(dataset, labels, max_features)
return dataset[:, indices], indices
@staticmethod
def _min_distance(data: np.array, _) -> np.array:
# chooses the lowest distance of every sample
indices = np.argmin(np.abs(data), axis=1)
return np.array(
[data[x, y] for x, y in zip(range(len(data[:, 0])), indices)]
)
@staticmethod
def _max_samples(data: np.array, y: np.array) -> np.array:
# select the class with max number of samples
_, samples = np.unique(y, return_counts=True)
selected = np.argmax(samples)
return data[:, selected]
def partition(self, samples: np.array, node: Snode):
"""Set the criteria to split arrays
"""
data = self._distances(node, samples)
if data.shape[0] < self._min_samples_split:
self._down = np.ones((data.shape[0]), dtype=bool)
return
if data.ndim > 1:
# split criteria for multiclass
data = self.decision_criteria(data, node._y)
self._down = data > 0
def _distances(self, node: Snode, data: np.ndarray) -> np.array:
"""Compute distances of the samples to the hyperplane of the node
:param node: node containing the svm classifier
:type node: Snode
:param data: samples to find out distance to hyperplane
:type data: np.ndarray
:return: array of shape (m, 1) with the distances of every sample to
the hyperplane of the node
:rtype: np.array
"""
return node._clf.decision_function(data[:, node._features])
def part(self, origin: np.array) -> list:
"""Split an array in two based on indices (down) and its complement
:param origin: dataset to split
:type origin: np.array
:param down: indices to use to split array
:type down: np.array
:return: list with two splits of the array
:rtype: list
"""
up = ~self._down
return [
origin[up] if any(up) else None,
origin[self._down] if any(self._down) else None,
]
class Stree(BaseEstimator, ClassifierMixin):
"""Estimator that is based on binary trees of svm nodes
can deal with sample_weights in predict, used in boosting sklearn methods
@@ -156,6 +320,7 @@ class Stree(BaseEstimator, ClassifierMixin):
criterion: str = "gini",
min_samples_split: int = 0,
max_features=None,
splitter: str = "random",
):
self.max_iter = max_iter
self.C = C
@@ -169,6 +334,7 @@ class Stree(BaseEstimator, ClassifierMixin):
self.split_criteria = split_criteria
self.max_features = max_features
self.criterion = criterion
self.splitter = splitter
def _more_tags(self) -> dict:
"""Required by sklearn to supply features of the classifier
@@ -178,68 +344,6 @@ class Stree(BaseEstimator, ClassifierMixin):
"""
return {"requires_y": True}
def _split_array(self, origin: np.array, down: np.array) -> list:
"""Split an array in two based on indices (down) and its complement
:param origin: dataset to split
:type origin: np.array
:param down: indices to use to split array
:type down: np.array
:return: list with two splits of the array
:rtype: list
"""
up = ~down
return [
origin[up] if any(up) else None,
origin[down] if any(down) else None,
]
def _distances(self, node: Snode, data: np.ndarray) -> np.array:
"""Compute distances of the samples to the hyperplane of the node
:param node: node containing the svm classifier
:type node: Snode
:param data: samples to find out distance to hyperplane
:type data: np.ndarray
:return: array of shape (m, 1) with the distances of every sample to
the hyperplane of the node
:rtype: np.array
"""
return node._clf.decision_function(data[:, node._features])
def _min_distance(self, data: np.array, _) -> np.array:
# chooses the lowest distance of every sample
indices = np.argmin(np.abs(data), axis=1)
return np.array(
[data[x, y] for x, y in zip(range(len(data[:, 0])), indices)]
)
def _max_samples(self, data: np.array, y: np.array) -> np.array:
# select the class with max number of samples
_, samples = np.unique(y, return_counts=True)
selected = np.argmax(samples)
return data[:, selected]
def _split_criteria(self, data: np.array, node: Snode) -> np.array:
"""Set the criteria to split arrays
:param data: distances of samples to hyperplanes shape (m, nclasses)
if nclasses > 2 else (m,)
:type data: np.array
:param node: node containing the svm classifier
:type node: Snode
:return: array of booleans of samples under or above zero
:rtype: np.array
"""
if data.shape[0] < self.min_samples_split:
return np.ones((data.shape[0]), dtype=bool)
if data.ndim > 1:
# split criteria for multiclass
data = getattr(self, f"_{self.split_criteria}")(data, node._y)
res = data > 0
return res
def fit(
self, X: np.ndarray, y: np.ndarray, sample_weight: np.array = None
) -> "Stree":
@@ -271,21 +375,20 @@ class Stree(BaseEstimator, ClassifierMixin):
f"Maximum depth has to be greater than 1... got (max_depth=\
{self.max_depth})"
)
if self.split_criteria not in ["min_distance", "max_samples"]:
raise ValueError(
f"split_criteria has to be min_distance or \
max_samples got ({self.split_criteria})"
)
if self.criterion not in ["gini", "entropy"]:
raise ValueError(
f"criterion must be gini or entropy got({self.criterion})"
)
check_classification_targets(y)
X, y = check_X_y(X, y)
sample_weight = _check_sample_weight(sample_weight, X)
check_classification_targets(y)
# Initialize computed parameters
self.splitter_ = Splitter(
clf=self._build_clf(),
criterion=self.criterion,
splitter_type=self.splitter,
criteria=self.split_criteria,
random_state=self.random_state,
min_samples_split=self.min_samples_split,
)
if self.random_state is not None:
random.seed(self.random_state)
self.classes_, y = np.unique(y, return_inverse=True)
@@ -295,7 +398,6 @@ class Stree(BaseEstimator, ClassifierMixin):
self.n_features_ = X.shape[1]
self.n_features_in_ = X.shape[1]
self.max_features_ = self._initialize_max_features()
self.criterion_function_ = getattr(self, f"_{self.criterion}")
self.tree_ = self.train(X, y, sample_weight, 1, "root")
self._build_predictor()
return self
@@ -339,15 +441,15 @@ class Stree(BaseEstimator, ClassifierMixin):
)
# Train the model
clf = self._build_clf()
Xs, features = self._get_subspace(X)
Xs, features = self.splitter_.get_subspace(X, y, self.max_features_)
clf.fit(Xs, y, sample_weight=sample_weight)
impurity = self.criterion_function_(y)
impurity = self.splitter_.impurity(y)
node = Snode(clf, X, y, features, impurity, title)
self.depth_ = max(depth, self.depth_)
down = self._split_criteria(self._distances(node, X), node)
X_U, X_D = self._split_array(X, down)
y_u, y_d = self._split_array(y, down)
sw_u, sw_d = self._split_array(sample_weight, down)
self.splitter_.partition(X, node)
X_U, X_D = self.splitter_.part(X)
y_u, y_d = self.splitter_.part(y)
sw_u, sw_d = self.splitter_.part(sample_weight)
if X_U is None or X_D is None:
# didn't part anything
return Snode(
@@ -431,9 +533,9 @@ class Stree(BaseEstimator, ClassifierMixin):
# set a class for every sample in dataset
prediction = np.full((xp.shape[0], 1), node._class)
return prediction, indices
down = self._split_criteria(self._distances(node, xp), node)
x_u, x_d = self._split_array(xp, down)
i_u, i_d = self._split_array(indices, down)
self.splitter_.partition(xp, node)
x_u, x_d = self.splitter_.part(xp)
i_u, i_d = self.splitter_.part(indices)
prx_u, prin_u = predict_class(x_u, i_u, node.get_up())
prx_d, prin_d = predict_class(x_d, i_d, node.get_down())
return np.append(prx_u, prx_d), np.append(prin_u, prin_d)
@@ -536,29 +638,3 @@ class Stree(BaseEstimator, ClassifierMixin):
f"got ({self.max_features})"
)
return max_features
@staticmethod
def _gini(y: np.array) -> float:
_, count = np.unique(y, return_counts=True)
return 1 - np.sum(np.square(count / np.sum(count)))
@staticmethod
def _entropy(y: np.array) -> float:
_, count = np.unique(y, return_counts=True)
proportion = count / np.sum(count)
return -np.sum(proportion * np.log2(proportion))
def _get_subspace(self, dataset: np.array) -> list:
"""Return the best subspace to make a split
"""
def get_subspaces_set(dataset: np.array) -> np.array:
features = range(dataset.shape[1])
features_sets = list(combinations(features, self.max_features_))
if len(features_sets) > 1:
return features_sets[random.randint(0, len(features_sets) - 1)]
else:
return features_sets[0]
indices = get_subspaces_set(dataset)
return dataset[:, indices], indices

View File

@@ -1,3 +1,3 @@
from .Strees import Stree, Snode, Siterator
from .Strees import Stree, Snode, Siterator, Splitter
__all__ = ["Stree", "Snode", "Siterator"]
__all__ = ["Stree", "Snode", "Siterator", "Splitter"]

View File

@@ -0,0 +1,142 @@
import os
import unittest
import numpy as np
from sklearn.svm import LinearSVC
from stree import Splitter
from .utils import load_dataset
class Splitter_test(unittest.TestCase):
def __init__(self, *args, **kwargs):
self._random_state = 1
super().__init__(*args, **kwargs)
def build(
self,
clf=LinearSVC(),
min_samples_split=0,
splitter_type="random",
criterion="gini",
criteria="min_distance",
random_state=None,
):
return Splitter(
clf=clf,
min_samples_split=min_samples_split,
splitter_type=splitter_type,
criterion=criterion,
criteria=criteria,
random_state=random_state,
)
@classmethod
def setUp(cls):
os.environ["TESTING"] = "1"
def test_init(self):
with self.assertRaises(ValueError):
self.build(criterion="duck")
with self.assertRaises(ValueError):
self.build(splitter_type="duck")
with self.assertRaises(ValueError):
self.build(criteria="duck")
with self.assertRaises(ValueError):
self.build(clf=None)
for splitter_type in ["best", "random"]:
for criterion in ["gini", "entropy"]:
for criteria in ["min_distance", "max_samples"]:
tcl = self.build(
splitter_type=splitter_type,
criterion=criterion,
criteria=criteria,
)
self.assertEqual(splitter_type, tcl._splitter_type)
self.assertEqual(criterion, tcl._criterion)
self.assertEqual(criteria, tcl._criteria)
def test_gini(self):
y = [0, 1, 1, 1, 1, 1, 0, 0, 0, 1]
expected = 0.48
self.assertEqual(expected, Splitter._gini(y))
tcl = self.build(criterion="gini")
self.assertEqual(expected, tcl.criterion_function(y))
def test_entropy(self):
y = [0, 1, 1, 1, 1, 1, 0, 0, 0, 1]
expected = 0.9709505944546686
self.assertAlmostEqual(expected, Splitter._entropy(y))
tcl = self.build(criterion="entropy")
self.assertEqual(expected, tcl.criterion_function(y))
def test_information_gain(self):
yu = np.array([0, 1, 1, 1, 1, 1])
yd = np.array([0, 0, 0, 1])
values_expected = [
("gini", 0.31666666666666665),
("entropy", 0.7145247027726656),
]
for criterion, expected in values_expected:
tcl = self.build(criterion=criterion)
computed = tcl.information_gain(yu, yd)
self.assertAlmostEqual(expected, computed)
def test_max_samples(self):
tcl = self.build(criteria="max_samples")
data = np.array(
[
[-0.1, 0.2, -0.3],
[0.7, 0.01, -0.1],
[0.7, -0.9, 0.5],
[0.1, 0.2, 0.3],
]
)
expected = np.array([0.2, 0.01, -0.9, 0.2])
y = [1, 2, 1, 0]
computed = tcl._max_samples(data, y)
self.assertEqual((4,), computed.shape)
self.assertListEqual(expected.tolist(), computed.tolist())
def test_min_distance(self):
tcl = self.build()
data = np.array(
[
[-0.1, 0.2, -0.3],
[0.7, 0.01, -0.1],
[0.7, -0.9, 0.5],
[0.1, 0.2, 0.3],
]
)
expected = np.array([-0.1, 0.01, 0.5, 0.1])
computed = tcl._min_distance(data, None)
self.assertEqual((4,), computed.shape)
self.assertListEqual(expected.tolist(), computed.tolist())
def test_splitter_parameter(self):
expected_values = [
[1, 7, 9],
[1, 7, 9],
[1, 7, 9],
[1, 7, 9],
[0, 5, 6],
[0, 5, 6],
[0, 5, 6],
[0, 5, 6],
]
X, y = load_dataset(self._random_state, n_features=12)
for splitter_type in ["best", "random"]:
for criterion in ["gini", "entropy"]:
for criteria in ["min_distance", "max_samples"]:
tcl = self.build(
splitter_type=splitter_type,
criterion=criterion,
criteria=criteria,
random_state=self._random_state,
)
expected = expected_values.pop(0)
dataset, computed = tcl.get_subspace(X, y, max_features=3)
self.assertListEqual(expected, list(computed))
self.assertListEqual(
X[:, computed].tolist(), dataset.tolist()
)

View File

@@ -204,13 +204,11 @@ class Stree_test(unittest.TestCase):
self.assertEqual(0, len(list(tcl)))
def test_min_samples_split(self):
tcl_split = Stree(min_samples_split=3)
tcl_nosplit = Stree(min_samples_split=4)
dataset = [[1], [2], [3]], [1, 1, 0]
tcl_split.fit(*dataset)
tcl_split = Stree(min_samples_split=3).fit(*dataset)
self.assertIsNotNone(tcl_split.tree_.get_down())
self.assertIsNotNone(tcl_split.tree_.get_up())
tcl_nosplit.fit(*dataset)
tcl_nosplit = Stree(min_samples_split=4).fit(*dataset)
self.assertIsNone(tcl_nosplit.tree_.get_down())
self.assertIsNone(tcl_nosplit.tree_.get_up())
@@ -265,37 +263,6 @@ class Stree_test(unittest.TestCase):
outcome = outcomes[name][f"{criteria} {kernel}"]
self.assertAlmostEqual(outcome, clf.score(px, py))
def test_min_distance(self):
clf = Stree()
data = np.array(
[
[-0.1, 0.2, -0.3],
[0.7, 0.01, -0.1],
[0.7, -0.9, 0.5],
[0.1, 0.2, 0.3],
]
)
expected = np.array([-0.1, 0.01, 0.5, 0.1])
computed = clf._min_distance(data, None)
self.assertEqual((4,), computed.shape)
self.assertListEqual(expected.tolist(), computed.tolist())
def test_max_samples(self):
clf = Stree()
data = np.array(
[
[-0.1, 0.2, -0.3],
[0.7, 0.01, -0.1],
[0.7, -0.9, 0.5],
[0.1, 0.2, 0.3],
]
)
expected = np.array([0.2, 0.01, -0.9, 0.2])
y = [1, 2, 1, 0]
computed = clf._max_samples(data, y)
self.assertEqual((4,), computed.shape)
self.assertListEqual(expected.tolist(), computed.tolist())
def test_max_features(self):
n_features = 16
expected_values = [
@@ -334,7 +301,9 @@ class Stree_test(unittest.TestCase):
for max_features, expected in expected_values:
clf.set_params(**dict(max_features=max_features))
clf.fit(dataset, y)
computed, indices = clf._get_subspace(dataset)
computed, indices = clf.splitter_.get_subspace(
dataset, y, clf.max_features_
)
self.assertListEqual(
dataset[:, indices].tolist(), computed.tolist()
)
@@ -345,22 +314,6 @@ class Stree_test(unittest.TestCase):
with self.assertRaises(ValueError):
clf.fit(*load_dataset())
def test_gini(self):
y = [0, 1, 1, 1, 1, 1, 0, 0, 0, 1]
expected = 0.48
self.assertEqual(expected, Stree._gini(y))
clf = Stree(criterion="gini")
clf.fit(*load_dataset())
self.assertEqual(expected, clf.criterion_function_(y))
def test_entropy(self):
y = [0, 1, 1, 1, 1, 1, 0, 0, 0, 1]
expected = 0.9709505944546686
self.assertAlmostEqual(expected, Stree._entropy(y))
clf = Stree(criterion="entropy")
clf.fit(*load_dataset())
self.assertEqual(expected, clf.criterion_function_(y))
def test_predict_feature_dimensions(self):
X = np.random.rand(10, 5)
y = np.random.randint(0, 2, 10)
@@ -374,3 +327,8 @@ class Stree_test(unittest.TestCase):
clf = Stree(random_state=self._random_state, max_features=2)
clf.fit(X, y)
self.assertAlmostEqual(0.9426666666666667, clf.score(X, y))
def test_bogus_splitter_parameter(self):
clf = Stree(splitter="duck")
with self.assertRaises(ValueError):
clf.fit(*load_dataset())

View File

@@ -1,4 +1,5 @@
from .Stree_test import Stree_test
from .Snode_test import Snode_test
from .Splitter_test import Splitter_test
__all__ = ["Stree_test", "Snode_test"]
__all__ = ["Stree_test", "Snode_test", "Splitter_test"]

View File

@@ -1,10 +1,10 @@
from sklearn.datasets import make_classification
def load_dataset(random_state=0, n_classes=2):
def load_dataset(random_state=0, n_classes=2, n_features=3):
X, y = make_classification(
n_samples=1500,
n_features=3,
n_features=n_features,
n_informative=3,
n_redundant=0,
n_repeated=0,