Fix predict and predict_proba

Add static types
Fix tests
This commit is contained in:
Ricardo Montañana Gómez 2020-07-06 00:12:46 +02:00
parent e1bfa9f9bf
commit b17582e93a
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
4 changed files with 119 additions and 104 deletions

File diff suppressed because one or more lines are too long

View File

@ -55,7 +55,7 @@
{ {
"output_type": "stream", "output_type": "stream",
"name": "stdout", "name": "stdout",
"text": "****************************** Results for wine ******************************\nTraining stree...\nScore: 94.444 in 0.17 seconds\nTraining odte...\nScore: 97.222 in 2.70 seconds\nTraining adaboost...\nScore: 94.444 in 0.60 seconds\nTraining bagging...\nScore: 100.000 in 2.55 seconds\n" "text": "****************************** Results for wine ******************************\nTraining stree...\nScore: 94.444 in 0.19 seconds\nTraining odte...\nScore: 100.000 in 3.43 seconds\nTraining adaboost...\nScore: 94.444 in 0.76 seconds\nTraining bagging...\nScore: 100.000 in 3.27 seconds\n"
} }
], ],
"source": [ "source": [
@ -102,7 +102,7 @@
{ {
"output_type": "stream", "output_type": "stream",
"name": "stdout", "name": "stdout",
"text": "****************************** Results for iris ******************************\nTraining stree...\nScore: 100.000 in 0.02 seconds\nTraining odte...\nScore: 93.333 in 0.12 seconds\nTraining adaboost...\nScore: 83.333 in 0.01 seconds\nTraining bagging...\nScore: 100.000 in 0.11 seconds\n" "text": "****************************** Results for iris ******************************\nTraining stree...\nScore: 100.000 in 0.02 seconds\nTraining odte...\nScore: 100.000 in 0.15 seconds\nTraining adaboost...\nScore: 83.333 in 0.01 seconds\nTraining bagging...\nScore: 96.667 in 0.13 seconds\n"
} }
], ],
"source": [ "source": [
@ -124,7 +124,7 @@
{ {
"output_type": "stream", "output_type": "stream",
"name": "stdout", "name": "stdout",
"text": "{'fit_time': array([0.15752316, 0.18354201, 0.14742589, 0.13827896, 0.14534211]), 'score_time': array([0.00940681, 0.01064587, 0.01085019, 0.00925183, 0.00878191]), 'test_score': array([0.8 , 0.93333333, 0.93333333, 0.93333333, 0.96666667]), 'train_score': array([0.875 , 0.95 , 0.98333333, 0.98333333, 0.95833333])}\n91.333 +- 0.058\n" "text": "{'fit_time': array([0.23599219, 0.22772503, 0.21689606, 0.20017815, 0.22257805]), 'score_time': array([0.01378369, 0.01322389, 0.0125649 , 0.01751685, 0.01062703]), 'test_score': array([1. , 1. , 1. , 0.93333333, 1. ]), 'train_score': array([0.98333333, 0.96666667, 0.99166667, 0.99166667, 0.975 ])}\n98.667 +- 0.027\n"
} }
], ],
"source": [ "source": [
@ -143,7 +143,7 @@
{ {
"output_type": "stream", "output_type": "stream",
"name": "stdout", "name": "stdout",
"text": "{'fit_time': array([0.01752877, 0.03304005, 0.03542018, 0.03398919, 0.03945518]), 'score_time': array([0.00135112, 0.00164104, 0.00159597, 0.0018959 , 0.00189495]), 'test_score': array([1. , 0.93333333, 0.93333333, 0.93333333, 0.96666667]), 'train_score': array([0.93333333, 0.96666667, 0.96666667, 0.96666667, 0.95 ])}\n95.333 +- 0.027\n" "text": "{'fit_time': array([0.02912688, 0.05858397, 0.06724691, 0.02860498, 0.03802919]), 'score_time': array([0.0024271 , 0.0022819 , 0.00219584, 0.00195408, 0.00342584]), 'test_score': array([1. , 0.93333333, 0.93333333, 0.93333333, 0.96666667]), 'train_score': array([0.93333333, 0.96666667, 0.96666667, 0.96666667, 0.95 ])}\n95.333 +- 0.027\n"
} }
], ],
"source": [ "source": [
@ -151,6 +151,30 @@
"print(cross)\n", "print(cross)\n",
"print(f\"{np.mean(cross['test_score'])*100:.3f} +- {np.std(cross['test_score']):.3f}\")" "print(f\"{np.mean(cross['test_score'])*100:.3f} +- {np.std(cross['test_score']):.3f}\")"
] ]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "1 functools.partial(<function check_no_attributes_set_in_init at 0x12b593290>, 'Odte')\n2 functools.partial(<function check_estimators_dtypes at 0x12b58d3b0>, 'Odte')\n3 functools.partial(<function check_fit_score_takes_y at 0x12b58d290>, 'Odte')\n4 functools.partial(<function check_sample_weights_pandas_series at 0x12b586b90>, 'Odte')\n5 functools.partial(<function check_sample_weights_not_an_array at 0x12b586cb0>, 'Odte')\n6 functools.partial(<function check_sample_weights_list at 0x12b586dd0>, 'Odte')\n7 functools.partial(<function check_sample_weights_shape at 0x12b586ef0>, 'Odte')\n8 functools.partial(<function check_sample_weights_invariance at 0x12b58a050>, 'Odte')\n9 functools.partial(<function check_estimators_fit_returns_self at 0x12b5913b0>, 'Odte')\n10 functools.partial(<function check_estimators_fit_returns_self at 0x12b5913b0>, 'Odte', readonly_memmap=True)\n11 functools.partial(<function check_complex_data at 0x12b58a200>, 'Odte')\n12 functools.partial(<function check_dtype_object at 0x12b58a170>, 'Odte')\n13 functools.partial(<function check_estimators_empty_data_messages at 0x12b58d4d0>, 'Odte')\n14 functools.partial(<function check_pipeline_consistency at 0x12b58d170>, 'Odte')\n15 functools.partial(<function check_estimators_nan_inf at 0x12b58d5f0>, 'Odte')\n16 functools.partial(<function check_estimators_overwrite_params at 0x12b593170>, 'Odte')\n17 functools.partial(<function check_estimator_sparse_data at 0x12b586a70>, 'Odte')\n18 functools.partial(<function check_estimators_pickle at 0x12b58d830>, 'Odte')\n19 functools.partial(<function check_classifier_data_not_an_array at 0x12b5934d0>, 'Odte')\n20 functools.partial(<function check_classifiers_one_label at 0x12b58def0>, 'Odte')\n21 functools.partial(<function check_classifiers_classes at 0x12b591950>, 'Odte')\n22 functools.partial(<function check_estimators_partial_fit_n_features at 0x12b58d950>, 'Odte')\n23 functools.partial(<function check_classifiers_train at 0x12b591050>, 'Odte')\n24 functools.partial(<function check_classifiers_train at 0x12b591050>, 'Odte', readonly_memmap=True)\n25 functools.partial(<function check_classifiers_train at 0x12b591050>, 'Odte', readonly_memmap=True, X_dtype='float32')\n26 functools.partial(<function check_classifiers_regression_target at 0x12b593f80>, 'Odte')\n27 functools.partial(<function check_supervised_y_no_nan at 0x12b57eb90>, 'Odte')\n28 functools.partial(<function check_supervised_y_2d at 0x12b5915f0>, 'Odte')\n29 functools.partial(<function check_estimators_unfitted at 0x12b5914d0>, 'Odte')\n30 functools.partial(<function check_non_transformer_estimators_n_iter at 0x12b593b00>, 'Odte')\n31 functools.partial(<function check_decision_proba_consistency at 0x12b5970e0>, 'Odte')\n32 functools.partial(<function check_fit2d_predict1d at 0x12b58a710>, 'Odte')\n33 functools.partial(<function check_methods_subset_invariance at 0x12b58a8c0>, 'Odte')\n34 functools.partial(<function check_fit2d_1sample at 0x12b58a9e0>, 'Odte')\n35 functools.partial(<function check_fit2d_1feature at 0x12b58ab00>, 'Odte')\n36 functools.partial(<function check_fit1d at 0x12b58ac20>, 'Odte')\n37 functools.partial(<function check_get_params_invariance at 0x12b593d40>, 'Odte')\n38 functools.partial(<function check_set_params at 0x12b593e60>, 'Odte')\n39 functools.partial(<function check_dict_unchanged at 0x12b58a320>, 'Odte')\n40 functools.partial(<function check_dont_overwrite_parameters at 0x12b58a5f0>, 'Odte')\n41 functools.partial(<function check_fit_idempotent at 0x12b597290>, 'Odte')\n42 functools.partial(<function check_n_features_in at 0x12b597320>, 'Odte')\n"
}
],
"source": [
"from sklearn.utils.estimator_checks import check_estimator\n",
"# Make checks one by one\n",
"c = 0\n",
"checks = check_estimator(Odte(), generate_only=True)\n",
"for check in checks:\n",
" c += 1\n",
" print(c, check[1])\n",
" check[1](check[0])"
]
} }
], ],
"metadata": { "metadata": {

View File

@ -5,33 +5,32 @@ __license__ = "MIT"
__version__ = "0.1" __version__ = "0.1"
Build a forest of oblique trees based on STree Build a forest of oblique trees based on STree
""" """
from __future__ import annotations
import random import random
from typing import Union import sys
from typing import Union, Optional, Tuple, List
from itertools import combinations from itertools import combinations
import numpy as np import numpy as np # type: ignore
from sklearn.utils import check_consistent_length from sklearn.utils.multiclass import ( # type: ignore
from sklearn.metrics._classification import _weighted_sum, _check_targets check_classification_targets,
from sklearn.utils.multiclass import check_classification_targets )
from sklearn.base import clone, ClassifierMixin from sklearn.base import clone, BaseEstimator, ClassifierMixin # type: ignore
from sklearn.ensemble import BaseEnsemble from sklearn.ensemble import BaseEnsemble # type: ignore
from sklearn.utils.validation import ( from sklearn.utils.validation import ( # type: ignore
check_X_y,
check_array,
check_is_fitted, check_is_fitted,
_check_sample_weight, _check_sample_weight,
) )
from stree import Stree from stree import Stree # type: ignore
class Odte(BaseEnsemble, ClassifierMixin): class Odte(BaseEnsemble, ClassifierMixin): # type: ignore
def __init__( def __init__(
self, self,
base_estimator=None, base_estimator: BaseEstimator = None,
random_state: int = None, random_state: int = 0,
max_features: Union[str, int, float] = 1.0, max_features: Optional[Union[str, int, float]] = None,
max_samples: Union[int, float] = None, max_samples: Optional[Union[int, float]] = None,
n_estimators: int = 100, n_estimators: int = 100,
): ):
base_estimator = ( base_estimator = (
@ -47,11 +46,9 @@ class Odte(BaseEnsemble, ClassifierMixin):
self.max_features = max_features self.max_features = max_features
self.max_samples = max_samples # size of bootstrap self.max_samples = max_samples # size of bootstrap
def _more_tags(self) -> dict:
return {"requires_y": True}
def _initialize_random(self) -> np.random.mtrand.RandomState: def _initialize_random(self) -> np.random.mtrand.RandomState:
if self.random_state is None: if self.random_state is None:
self.random_state = random.randint(0, sys.maxint)
return np.random.mtrand._rand return np.random.mtrand._rand
return np.random.RandomState(self.random_state) return np.random.RandomState(self.random_state)
@ -63,7 +60,7 @@ class Odte(BaseEnsemble, ClassifierMixin):
return np.ones((n_samples,), dtype=np.float64) return np.ones((n_samples,), dtype=np.float64)
return sample_weight.copy() return sample_weight.copy()
def _validate_estimator(self): def _validate_estimator(self) -> None:
"""Check the estimator and set the base_estimator_ attribute.""" """Check the estimator and set the base_estimator_ attribute."""
super()._validate_estimator( super()._validate_estimator(
default=Stree(random_state=self.random_state) default=Stree(random_state=self.random_state)
@ -71,7 +68,7 @@ class Odte(BaseEnsemble, ClassifierMixin):
def fit( def fit(
self, X: np.array, y: np.array, sample_weight: np.array = None self, X: np.array, y: np.array, sample_weight: np.array = None
) -> "Odte": ) -> Odte:
# Check parameters are Ok. # Check parameters are Ok.
if self.n_estimators < 3: if self.n_estimators < 3:
raise ValueError( raise ValueError(
@ -79,34 +76,36 @@ class Odte(BaseEnsemble, ClassifierMixin):
{self.n_estimators})" {self.n_estimators})"
) )
check_classification_targets(y) check_classification_targets(y)
X, y = check_X_y(X, y) X, y = self._validate_data(X, y)
sample_weight = _check_sample_weight( sample_weight = _check_sample_weight(
sample_weight, X, dtype=np.float64 sample_weight, X, dtype=np.float64
) )
check_classification_targets(y) check_classification_targets(y)
# Initialize computed parameters # Initialize computed parameters
# Build the estimator # Build the estimator
self.n_features_in_ = X.shape[1]
self.n_features_ = X.shape[1]
self.max_features_ = self._initialize_max_features() self.max_features_ = self._initialize_max_features()
# build base_estimator_
self._validate_estimator() self._validate_estimator()
self.classes_, y = np.unique(y, return_inverse=True) self.classes_, y = np.unique(y, return_inverse=True)
self.n_classes_ = self.classes_.shape[0] self.n_classes_: int = self.classes_.shape[0]
self.estimators_ = [] self.estimators_: List[BaseEstimator] = []
self.subspaces_ = [] self.subspaces_: List[Tuple[int, ...]] = []
self._train(X, y, sample_weight) self._train(X, y, sample_weight)
return self return self
def _train( def _train(
self, X: np.array, y: np.array, sample_weight: np.array self, X: np.array, y: np.array, sample_weight: np.array
) -> "Odte": ) -> None:
random_box = self._initialize_random() random_box = self._initialize_random()
random_seed = self.random_state
n_samples = X.shape[0] n_samples = X.shape[0]
weights = self._initialize_sample_weight(sample_weight, n_samples) weights = self._initialize_sample_weight(sample_weight, n_samples)
boot_samples = self._get_bootstrap_n_samples(n_samples) boot_samples = self._get_bootstrap_n_samples(n_samples)
for _ in range(self.n_estimators): for _ in range(self.n_estimators):
# Build clf # Build clf
clf = clone(self.base_estimator_) clf = clone(self.base_estimator_)
clf.random_state = random_seed
random_seed += 1
self.estimators_.append(clf) self.estimators_.append(clf)
# bootstrap # bootstrap
indices = random_box.randint(0, n_samples, boot_samples) indices = random_box.randint(0, n_samples, boot_samples)
@ -121,7 +120,7 @@ class Odte(BaseEnsemble, ClassifierMixin):
bootstrap[:, features], y[indices], current_weights[indices] bootstrap[:, features], y[indices], current_weights[indices]
) )
def _get_bootstrap_n_samples(self, n_samples) -> int: def _get_bootstrap_n_samples(self, n_samples: int) -> int:
if self.max_samples is None: if self.max_samples is None:
return n_samples return n_samples
if isinstance(self.max_samples, int): if isinstance(self.max_samples, int):
@ -144,11 +143,11 @@ class Odte(BaseEnsemble, ClassifierMixin):
def _initialize_max_features(self) -> int: def _initialize_max_features(self) -> int:
if isinstance(self.max_features, str): if isinstance(self.max_features, str):
if self.max_features == "auto": if self.max_features == "auto":
max_features = max(1, int(np.sqrt(self.n_features_))) max_features = max(1, int(np.sqrt(self.n_features_in_)))
elif self.max_features == "sqrt": elif self.max_features == "sqrt":
max_features = max(1, int(np.sqrt(self.n_features_))) max_features = max(1, int(np.sqrt(self.n_features_in_)))
elif self.max_features == "log2": elif self.max_features == "log2":
max_features = max(1, int(np.log2(self.n_features_))) max_features = max(1, int(np.log2(self.n_features_in_)))
else: else:
raise ValueError( raise ValueError(
"Invalid value for max_features. " "Invalid value for max_features. "
@ -156,13 +155,13 @@ class Odte(BaseEnsemble, ClassifierMixin):
"'sqrt' or 'log2'." "'sqrt' or 'log2'."
) )
elif self.max_features is None: elif self.max_features is None:
max_features = self.n_features_ max_features = self.n_features_in_
elif isinstance(self.max_features, int): elif isinstance(self.max_features, int):
max_features = abs(self.max_features) max_features = abs(self.max_features)
else: # float else: # float
if self.max_features > 0.0: if self.max_features > 0.0:
max_features = max( max_features = max(
1, int(self.max_features * self.n_features_) 1, int(self.max_features * self.n_features_in_)
) )
else: else:
raise ValueError( raise ValueError(
@ -174,7 +173,7 @@ class Odte(BaseEnsemble, ClassifierMixin):
def _get_random_subspace( def _get_random_subspace(
self, dataset: np.array, labels: np.array self, dataset: np.array, labels: np.array
) -> np.array: ) -> Tuple[int, ...]:
features = range(dataset.shape[1]) features = range(dataset.shape[1])
features_sets = list(combinations(features, self.max_features_)) features_sets = list(combinations(features, self.max_features_))
if len(features_sets) > 1: if len(features_sets) > 1:
@ -185,35 +184,16 @@ class Odte(BaseEnsemble, ClassifierMixin):
def predict(self, X: np.array) -> np.array: def predict(self, X: np.array) -> np.array:
proba = self.predict_proba(X) proba = self.predict_proba(X)
return self.classes_.take((np.argmax(proba, axis=1)), axis=0) return self.classes_[np.argmax(proba, axis=1)]
def predict_proba(self, X: np.array) -> np.array: def predict_proba(self, X: np.array) -> np.array:
check_is_fitted(self, ["estimators_"]) check_is_fitted(self, "estimators_")
# Input validation # Input validation
X = check_array(X) X = self._validate_data(X, reset=False)
if self.n_features_ != X.shape[1]: n_samples = X.shape[0]
raise ValueError( result = np.zeros((n_samples, self.n_classes_))
"Number of features of the model must "
"match the input. Model n_features is {0} and "
"input n_features is {1}."
"".format(self.n_features_, X.shape[1])
)
for tree, features in zip(self.estimators_, self.subspaces_): for tree, features in zip(self.estimators_, self.subspaces_):
n_samples = X.shape[0]
result = np.zeros((n_samples, self.n_classes_))
predictions = tree.predict(X[:, features]) predictions = tree.predict(X[:, features])
for i in range(n_samples): for i in range(n_samples):
result[i, predictions[i]] += 1 result[i, predictions[i]] += 1
return result return result / self.n_estimators
def score(
self, X: np.array, y: np.array, sample_weight: np.array = None
) -> float:
check_classification_targets(y)
X, y = check_X_y(X, y)
y_pred = self.predict(X).reshape(y.shape)
# Compute accuracy for each possible representation
_, y_true, y_pred = _check_targets(y, y_pred)
check_consistent_length(y_true, y_pred, sample_weight)
score = y_true == y_pred
return _weighted_sum(score, sample_weight, normalize=True)

View File

@ -39,13 +39,13 @@ class Odte_test(unittest.TestCase):
def test_initialize_max_feature(self): def test_initialize_max_feature(self):
expected_values = [ expected_values = [
[0, 4, 10, 11], [0, 5, 6, 15],
[0, 2, 3, 5, 14, 15], [0, 2, 3, 9, 11, 14],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
[0, 4, 10, 11], [0, 5, 6, 15],
[0, 4, 10, 11], [0, 5, 6, 15],
[0, 4, 10, 11], [0, 5, 6, 15],
] ]
X, y = load_dataset( X, y = load_dataset(
random_state=self._random_state, n_features=16, n_samples=10 random_state=self._random_state, n_features=16, n_samples=10
@ -91,7 +91,7 @@ class Odte_test(unittest.TestCase):
warnings.filterwarnings("ignore", category=ConvergenceWarning) warnings.filterwarnings("ignore", category=ConvergenceWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning) warnings.filterwarnings("ignore", category=RuntimeWarning)
X, y = [[1, 2], [5, 6], [9, 10], [16, 17]], [0, 1, 1, 2] X, y = [[1, 2], [5, 6], [9, 10], [16, 17]], [0, 1, 1, 2]
expected = [1, 1, 1, 1] expected = [0, 1, 1, 1]
tclf = Odte(random_state=self._random_state, n_estimators=10,) tclf = Odte(random_state=self._random_state, n_estimators=10,)
tclf.set_params( tclf.set_params(
**dict( **dict(
@ -116,7 +116,7 @@ class Odte_test(unittest.TestCase):
def test_score(self): def test_score(self):
X, y = load_dataset(self._random_state) X, y = load_dataset(self._random_state)
expected = 0.948 expected = 0.9526666666666667
tclf = Odte( tclf = Odte(
random_state=self._random_state, random_state=self._random_state,
max_features=None, max_features=None,
@ -128,10 +128,10 @@ class Odte_test(unittest.TestCase):
def test_score_splitter_max_features(self): def test_score_splitter_max_features(self):
X, y = load_dataset(self._random_state, n_features=12, n_samples=150) X, y = load_dataset(self._random_state, n_features=12, n_samples=150)
results = [ results = [
0.6466666666666666, 1.0,
0.6466666666666666, 1.0,
0.9866666666666667, 0.9933333333333333,
0.9866666666666667, 0.9933333333333333,
] ]
for max_features in ["auto", None]: for max_features in ["auto", None]:
for splitter in ["best", "random"]: for splitter in ["best", "random"]: