Become sklearn classifier

This commit is contained in:
Ricardo Montañana Gómez 2020-06-30 11:14:05 +02:00
parent 580c93d92a
commit 98a28cd271
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
4 changed files with 126 additions and 52 deletions

View File

@ -9,6 +9,4 @@ exclude_lines =
raise NotImplementedError raise NotImplementedError
if __name__ == .__main__.: if __name__ == .__main__.:
ignore_errors = True ignore_errors = True
omit = omit =
odte/tests/*
odte/__init__.py

View File

@ -219,7 +219,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"# Oblique Decision Tree Ensemble\n", "# Oblique Decision Tree Ensemble\n",
"odte = Odte(random_state=random_state, n_estimators=10, max_features=None)" "odte = Odte(random_state=random_state, n_estimators=10, max_features=\"auto\")"
] ]
}, },
{ {

View File

@ -6,13 +6,15 @@ __version__ = "0.1"
Build a forest of oblique trees based on STree Build a forest of oblique trees based on STree
""" """
import random
from typing import Union
from itertools import combinations
import numpy as np import numpy as np
from sklearn.utils import check_consistent_length from sklearn.utils import check_consistent_length
from sklearn.metrics._classification import _weighted_sum, _check_targets from sklearn.metrics._classification import _weighted_sum, _check_targets
from sklearn.utils.multiclass import check_classification_targets from sklearn.utils.multiclass import check_classification_targets
from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.base import clone, ClassifierMixin
from scipy.stats import mode from sklearn.ensemble import BaseEnsemble
from sklearn.utils.validation import ( from sklearn.utils.validation import (
check_X_y, check_X_y,
check_array, check_array,
@ -23,44 +25,30 @@ from sklearn.utils.validation import (
from stree import Stree from stree import Stree
class Odte(BaseEstimator, ClassifierMixin): class Odte(BaseEnsemble, ClassifierMixin):
def __init__( def __init__(
self, self,
base_estimator=None,
random_state: int = None, random_state: int = None,
C: int = 1, max_features: Union[str, int, float] = 1.0,
max_samples: Union[int, float] = None,
n_estimators: int = 100, n_estimators: int = 100,
max_iter: int = 1000,
max_depth: int = None,
min_samples_split: int = 0,
split_criteria: str = "min_distance",
criterion: str = "gini",
tol: float = 1e-4,
gamma="scale",
degree: int = 3,
kernel: str = "linear",
max_features="auto",
max_samples=None,
splitter: str = "random",
): ):
base_estimator = (
Stree(random_state=random_state)
if base_estimator is None
else base_estimator
)
super().__init__(
base_estimator=base_estimator, n_estimators=n_estimators,
)
self.n_estimators = n_estimators self.n_estimators = n_estimators
self.random_state = random_state self.random_state = random_state
self.max_features = max_features self.max_features = max_features
self.max_samples = max_samples # size of bootstrap self.max_samples = max_samples # size of bootstrap
self.estimator_params = dict(
C=C, def _more_tags(self) -> dict:
random_state=random_state, return {"requires_y": True}
min_samples_split=min_samples_split,
max_depth=max_depth,
split_criteria=split_criteria,
criterion=criterion,
kernel=kernel,
max_iter=max_iter,
tol=tol,
degree=degree,
gamma=gamma,
splitter=splitter,
max_features=max_features,
)
def _initialize_random(self) -> np.random.mtrand.RandomState: def _initialize_random(self) -> np.random.mtrand.RandomState:
if self.random_state is None: if self.random_state is None:
@ -77,6 +65,12 @@ class Odte(BaseEstimator, ClassifierMixin):
else: else:
return sample_weight.copy() return sample_weight.copy()
def _validate_estimator(self):
"""Check the estimator and set the base_estimator_ attribute."""
super()._validate_estimator(
default=Stree(random_state=self.random_state)
)
def fit( def fit(
self, X: np.array, y: np.array, sample_weight: np.array = None self, X: np.array, y: np.array, sample_weight: np.array = None
) -> "Odte": ) -> "Odte":
@ -89,9 +83,16 @@ class Odte(BaseEstimator, ClassifierMixin):
# the rest of parameters are checked in estimator # the rest of parameters are checked in estimator
check_classification_targets(y) check_classification_targets(y)
X, y = check_X_y(X, y) X, y = check_X_y(X, y)
sample_weight = _check_sample_weight(sample_weight, X) sample_weight = _check_sample_weight(
sample_weight, X, dtype=np.float64
)
check_classification_targets(y) check_classification_targets(y)
# Initialize computed parameters # Initialize computed parameters
# Build the estimator
self.n_features_in_ = X.shape[1]
self.n_features = X.shape[1]
self.max_features_ = self._initialize_max_features()
self._validate_estimator()
self.classes_, y = np.unique(y, return_inverse=True) self.classes_, y = np.unique(y, return_inverse=True)
self.n_classes_ = self.classes_.shape[0] self.n_classes_ = self.classes_.shape[0]
self.estimators_ = [] self.estimators_ = []
@ -107,15 +108,17 @@ class Odte(BaseEstimator, ClassifierMixin):
boot_samples = self._get_bootstrap_n_samples(n_samples) boot_samples = self._get_bootstrap_n_samples(n_samples)
for _ in range(self.n_estimators): for _ in range(self.n_estimators):
# Build clf # Build clf
clf = Stree().set_params(**self.estimator_params) clf = clone(self.base_estimator_)
# clf.set_params(**self.estimator_params)
self.estimators_.append(clf) self.estimators_.append(clf)
# bootstrap # bootstrap
indices = random_box.randint(0, n_samples, boot_samples) indices = random_box.randint(0, n_samples, boot_samples)
# update weights with the chosen samples # update weights with the chosen samples
weights_update = np.bincount(indices, minlength=n_samples) weights_update = np.bincount(indices, minlength=n_samples)
features = self.get_subspace(X, y)
current_weights = weights * weights_update current_weights = weights * weights_update
# train the classifier # train the classifier
clf.fit(X[indices, :], y[indices], current_weights[indices]) clf.fit(X[indices, features], y[indices], current_weights[indices])
def _get_bootstrap_n_samples(self, n_samples) -> int: def _get_bootstrap_n_samples(self, n_samples) -> int:
if self.max_samples is None: if self.max_samples is None:
@ -137,15 +140,69 @@ class Odte(BaseEstimator, ClassifierMixin):
{type(self.max_samples)}" {type(self.max_samples)}"
) )
def _initialize_max_features(self) -> int:
if isinstance(self.max_features, str):
if self.max_features == "auto":
max_features = max(1, int(np.sqrt(self.n_features_)))
elif self.max_features == "sqrt":
max_features = max(1, int(np.sqrt(self.n_features_)))
elif self.max_features == "log2":
max_features = max(1, int(np.log2(self.n_features_)))
else:
raise ValueError(
"Invalid value for max_features. "
"Allowed string values are 'auto', "
"'sqrt' or 'log2'."
)
elif self.max_features is None:
max_features = self.n_features_
elif isinstance(self.max_features, int):
max_features = self.max_features
else: # float
if self.max_features > 0.0:
max_features = max(
1, int(self.max_features * self.n_features_)
)
else:
raise ValueError(
"Invalid value for max_features."
"Allowed float must be in range (0, 1] "
f"got ({self.max_features})"
)
return max_features
def _get_subspaces_set(
self, dataset: np.array, labels: np.array
) -> np.array:
features = range(dataset.shape[1])
features_sets = list(combinations(features, self.max_features_))
if len(features_sets) > 1:
index = random.randint(0, len(features_sets) - 1)
return features_sets[index]
else:
return features_sets[0]
def get_subspace(self, dataset: np.array, labels: np.array) -> list:
"""Return the best subspace to build a tree
"""
indices = self._get_subspaces_set(dataset, labels)
return dataset[:, indices], indices
def predict(self, X: np.array) -> np.array: def predict(self, X: np.array) -> np.array:
# todo proba = self.predict_proba(X)
return self.classes_.take((np.argmax(proba, axis=1)), axis=0)
def predict_proba(self, X: np.array) -> np.array:
check_is_fitted(self, ["estimators_"]) check_is_fitted(self, ["estimators_"])
# Input validation # Input validation
X = check_array(X) X = check_array(X)
result = np.empty((X.shape[0], self.n_estimators)) for tree in self.estimators_:
for index, tree in enumerate(self.estimators_): n_samples = X.shape[0]
result[:, index] = tree.predict(X) result = np.zeros((n_samples, self.n_classes_))
return mode(result, axis=1).mode.ravel() predictions = tree.predict(X)
for i in range(n_samples):
result[i, predictions[i]] += 1
return result
def score( def score(
self, X: np.array, y: np.array, sample_weight: np.array = None self, X: np.array, y: np.array, sample_weight: np.array = None

View File

@ -62,9 +62,13 @@ class Odte_test(unittest.TestCase):
warnings.filterwarnings("ignore", category=ConvergenceWarning) warnings.filterwarnings("ignore", category=ConvergenceWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning) warnings.filterwarnings("ignore", category=RuntimeWarning)
X, y = [[1, 2], [5, 6], [9, 10], [16, 17]], [0, 1, 1, 2] X, y = [[1, 2], [5, 6], [9, 10], [16, 17]], [0, 1, 1, 2]
expected = [0, 1, 1, 0] expected = [1, 1, 1, 1]
tclf = Odte( tclf = Odte(random_state=self._random_state, n_estimators=10,)
random_state=self._random_state, n_estimators=10, kernel="rbf" tclf.set_params(
**dict(
base_estimator__kernel="rbf",
base_estimator__random_state=self._random_state,
)
) )
computed = tclf.fit(X, y).predict(X) computed = tclf.fit(X, y).predict(X)
self.assertListEqual(expected, computed.tolist()) self.assertListEqual(expected, computed.tolist())
@ -77,32 +81,47 @@ class Odte_test(unittest.TestCase):
tclf = Odte( tclf = Odte(
random_state=self._random_state, random_state=self._random_state,
max_features=None, max_features=None,
kernel="linear",
max_samples=0.1, max_samples=0.1,
) )
tclf.set_params(**dict(base_estimator__kernel="linear",))
computed = tclf.fit(X, y).predict(X) computed = tclf.fit(X, y).predict(X)
self.assertListEqual(expected[:27].tolist(), computed[:27].tolist()) self.assertListEqual(expected[:27].tolist(), computed[:27].tolist())
def test_score(self): def test_score(self):
X, y = load_dataset(self._random_state) X, y = load_dataset(self._random_state)
expected = 0.9526666666666667 expected = 0.948
tclf = Odte( tclf = Odte(
random_state=self._random_state, max_features=None, n_estimators=10 random_state=self._random_state,
max_features=None,
n_estimators=10,
) )
computed = tclf.fit(X, y).score(X, y) computed = tclf.fit(X, y).score(X, y)
self.assertAlmostEqual(expected, computed) self.assertAlmostEqual(expected, computed)
def test_score_splitter_max_features(self): def test_score_splitter_max_features(self):
X, y = load_dataset(self._random_state, n_features=12, n_samples=150) X, y = load_dataset(self._random_state, n_features=12, n_samples=150)
results = [1.0, 0.94, 0.9933333333333333, 0.9933333333333333] results = [
0.9866666666666667,
0.9866666666666667,
0.9866666666666667,
0.9866666666666667,
]
for max_features in ["auto", None]: for max_features in ["auto", None]:
for splitter in ["best", "random"]: for splitter in ["best", "random"]:
tclf = Odte( tclf = Odte(
random_state=self._random_state, random_state=self._random_state,
splitter=splitter,
max_features=max_features, max_features=max_features,
n_estimators=10, n_estimators=10,
) )
tclf.set_params(**dict(base_estimator__splitter=splitter,))
expected = results.pop(0) expected = results.pop(0)
computed = tclf.fit(X, y).score(X, y) computed = tclf.fit(X, y).score(X, y)
self.assertAlmostEqual(expected, computed) self.assertAlmostEqual(expected, computed)
@staticmethod
def test_is_a_sklearn_classifier():
warnings.filterwarnings("ignore", category=ConvergenceWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
from sklearn.utils.estimator_checks import check_estimator
check_estimator(Odte())