Fix predict and predict_proba

Add static types
Fix tests
This commit is contained in:
Ricardo Montañana Gómez 2020-07-06 00:12:46 +02:00
parent e1bfa9f9bf
commit b17582e93a
Signed by: rmontanana
GPG Key ID: 46064262FD9A7ADE
4 changed files with 119 additions and 104 deletions

File diff suppressed because one or more lines are too long

View File

@ -55,7 +55,7 @@
{
"output_type": "stream",
"name": "stdout",
"text": "****************************** Results for wine ******************************\nTraining stree...\nScore: 94.444 in 0.17 seconds\nTraining odte...\nScore: 97.222 in 2.70 seconds\nTraining adaboost...\nScore: 94.444 in 0.60 seconds\nTraining bagging...\nScore: 100.000 in 2.55 seconds\n"
"text": "****************************** Results for wine ******************************\nTraining stree...\nScore: 94.444 in 0.19 seconds\nTraining odte...\nScore: 100.000 in 3.43 seconds\nTraining adaboost...\nScore: 94.444 in 0.76 seconds\nTraining bagging...\nScore: 100.000 in 3.27 seconds\n"
}
],
"source": [
@ -102,7 +102,7 @@
{
"output_type": "stream",
"name": "stdout",
"text": "****************************** Results for iris ******************************\nTraining stree...\nScore: 100.000 in 0.02 seconds\nTraining odte...\nScore: 93.333 in 0.12 seconds\nTraining adaboost...\nScore: 83.333 in 0.01 seconds\nTraining bagging...\nScore: 100.000 in 0.11 seconds\n"
"text": "****************************** Results for iris ******************************\nTraining stree...\nScore: 100.000 in 0.02 seconds\nTraining odte...\nScore: 100.000 in 0.15 seconds\nTraining adaboost...\nScore: 83.333 in 0.01 seconds\nTraining bagging...\nScore: 96.667 in 0.13 seconds\n"
}
],
"source": [
@ -124,7 +124,7 @@
{
"output_type": "stream",
"name": "stdout",
"text": "{'fit_time': array([0.15752316, 0.18354201, 0.14742589, 0.13827896, 0.14534211]), 'score_time': array([0.00940681, 0.01064587, 0.01085019, 0.00925183, 0.00878191]), 'test_score': array([0.8 , 0.93333333, 0.93333333, 0.93333333, 0.96666667]), 'train_score': array([0.875 , 0.95 , 0.98333333, 0.98333333, 0.95833333])}\n91.333 +- 0.058\n"
"text": "{'fit_time': array([0.23599219, 0.22772503, 0.21689606, 0.20017815, 0.22257805]), 'score_time': array([0.01378369, 0.01322389, 0.0125649 , 0.01751685, 0.01062703]), 'test_score': array([1. , 1. , 1. , 0.93333333, 1. ]), 'train_score': array([0.98333333, 0.96666667, 0.99166667, 0.99166667, 0.975 ])}\n98.667 +- 0.027\n"
}
],
"source": [
@ -143,7 +143,7 @@
{
"output_type": "stream",
"name": "stdout",
"text": "{'fit_time': array([0.01752877, 0.03304005, 0.03542018, 0.03398919, 0.03945518]), 'score_time': array([0.00135112, 0.00164104, 0.00159597, 0.0018959 , 0.00189495]), 'test_score': array([1. , 0.93333333, 0.93333333, 0.93333333, 0.96666667]), 'train_score': array([0.93333333, 0.96666667, 0.96666667, 0.96666667, 0.95 ])}\n95.333 +- 0.027\n"
"text": "{'fit_time': array([0.02912688, 0.05858397, 0.06724691, 0.02860498, 0.03802919]), 'score_time': array([0.0024271 , 0.0022819 , 0.00219584, 0.00195408, 0.00342584]), 'test_score': array([1. , 0.93333333, 0.93333333, 0.93333333, 0.96666667]), 'train_score': array([0.93333333, 0.96666667, 0.96666667, 0.96666667, 0.95 ])}\n95.333 +- 0.027\n"
}
],
"source": [
@ -151,6 +151,30 @@
"print(cross)\n",
"print(f\"{np.mean(cross['test_score'])*100:.3f} +- {np.std(cross['test_score']):.3f}\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "1 functools.partial(<function check_no_attributes_set_in_init at 0x12b593290>, 'Odte')\n2 functools.partial(<function check_estimators_dtypes at 0x12b58d3b0>, 'Odte')\n3 functools.partial(<function check_fit_score_takes_y at 0x12b58d290>, 'Odte')\n4 functools.partial(<function check_sample_weights_pandas_series at 0x12b586b90>, 'Odte')\n5 functools.partial(<function check_sample_weights_not_an_array at 0x12b586cb0>, 'Odte')\n6 functools.partial(<function check_sample_weights_list at 0x12b586dd0>, 'Odte')\n7 functools.partial(<function check_sample_weights_shape at 0x12b586ef0>, 'Odte')\n8 functools.partial(<function check_sample_weights_invariance at 0x12b58a050>, 'Odte')\n9 functools.partial(<function check_estimators_fit_returns_self at 0x12b5913b0>, 'Odte')\n10 functools.partial(<function check_estimators_fit_returns_self at 0x12b5913b0>, 'Odte', readonly_memmap=True)\n11 functools.partial(<function check_complex_data at 0x12b58a200>, 'Odte')\n12 functools.partial(<function check_dtype_object at 0x12b58a170>, 'Odte')\n13 functools.partial(<function check_estimators_empty_data_messages at 0x12b58d4d0>, 'Odte')\n14 functools.partial(<function check_pipeline_consistency at 0x12b58d170>, 'Odte')\n15 functools.partial(<function check_estimators_nan_inf at 0x12b58d5f0>, 'Odte')\n16 functools.partial(<function check_estimators_overwrite_params at 0x12b593170>, 'Odte')\n17 functools.partial(<function check_estimator_sparse_data at 0x12b586a70>, 'Odte')\n18 functools.partial(<function check_estimators_pickle at 0x12b58d830>, 'Odte')\n19 functools.partial(<function check_classifier_data_not_an_array at 0x12b5934d0>, 'Odte')\n20 functools.partial(<function check_classifiers_one_label at 0x12b58def0>, 'Odte')\n21 functools.partial(<function check_classifiers_classes at 0x12b591950>, 'Odte')\n22 functools.partial(<function check_estimators_partial_fit_n_features at 0x12b58d950>, 'Odte')\n23 functools.partial(<function check_classifiers_train at 0x12b591050>, 'Odte')\n24 functools.partial(<function check_classifiers_train at 0x12b591050>, 'Odte', readonly_memmap=True)\n25 functools.partial(<function check_classifiers_train at 0x12b591050>, 'Odte', readonly_memmap=True, X_dtype='float32')\n26 functools.partial(<function check_classifiers_regression_target at 0x12b593f80>, 'Odte')\n27 functools.partial(<function check_supervised_y_no_nan at 0x12b57eb90>, 'Odte')\n28 functools.partial(<function check_supervised_y_2d at 0x12b5915f0>, 'Odte')\n29 functools.partial(<function check_estimators_unfitted at 0x12b5914d0>, 'Odte')\n30 functools.partial(<function check_non_transformer_estimators_n_iter at 0x12b593b00>, 'Odte')\n31 functools.partial(<function check_decision_proba_consistency at 0x12b5970e0>, 'Odte')\n32 functools.partial(<function check_fit2d_predict1d at 0x12b58a710>, 'Odte')\n33 functools.partial(<function check_methods_subset_invariance at 0x12b58a8c0>, 'Odte')\n34 functools.partial(<function check_fit2d_1sample at 0x12b58a9e0>, 'Odte')\n35 functools.partial(<function check_fit2d_1feature at 0x12b58ab00>, 'Odte')\n36 functools.partial(<function check_fit1d at 0x12b58ac20>, 'Odte')\n37 functools.partial(<function check_get_params_invariance at 0x12b593d40>, 'Odte')\n38 functools.partial(<function check_set_params at 0x12b593e60>, 'Odte')\n39 functools.partial(<function check_dict_unchanged at 0x12b58a320>, 'Odte')\n40 functools.partial(<function check_dont_overwrite_parameters at 0x12b58a5f0>, 'Odte')\n41 functools.partial(<function check_fit_idempotent at 0x12b597290>, 'Odte')\n42 functools.partial(<function check_n_features_in at 0x12b597320>, 'Odte')\n"
}
],
"source": [
"from sklearn.utils.estimator_checks import check_estimator\n",
"# Make checks one by one\n",
"c = 0\n",
"checks = check_estimator(Odte(), generate_only=True)\n",
"for check in checks:\n",
" c += 1\n",
" print(c, check[1])\n",
" check[1](check[0])"
]
}
],
"metadata": {

View File

@ -5,33 +5,32 @@ __license__ = "MIT"
__version__ = "0.1"
Build a forest of oblique trees based on STree
"""
from __future__ import annotations
import random
from typing import Union
import sys
from typing import Union, Optional, Tuple, List
from itertools import combinations
import numpy as np
from sklearn.utils import check_consistent_length
from sklearn.metrics._classification import _weighted_sum, _check_targets
from sklearn.utils.multiclass import check_classification_targets
from sklearn.base import clone, ClassifierMixin
from sklearn.ensemble import BaseEnsemble
from sklearn.utils.validation import (
check_X_y,
check_array,
import numpy as np # type: ignore
from sklearn.utils.multiclass import ( # type: ignore
check_classification_targets,
)
from sklearn.base import clone, BaseEstimator, ClassifierMixin # type: ignore
from sklearn.ensemble import BaseEnsemble # type: ignore
from sklearn.utils.validation import ( # type: ignore
check_is_fitted,
_check_sample_weight,
)
from stree import Stree
from stree import Stree # type: ignore
class Odte(BaseEnsemble, ClassifierMixin):
class Odte(BaseEnsemble, ClassifierMixin): # type: ignore
def __init__(
self,
base_estimator=None,
random_state: int = None,
max_features: Union[str, int, float] = 1.0,
max_samples: Union[int, float] = None,
base_estimator: BaseEstimator = None,
random_state: int = 0,
max_features: Optional[Union[str, int, float]] = None,
max_samples: Optional[Union[int, float]] = None,
n_estimators: int = 100,
):
base_estimator = (
@ -47,11 +46,9 @@ class Odte(BaseEnsemble, ClassifierMixin):
self.max_features = max_features
self.max_samples = max_samples # size of bootstrap
def _more_tags(self) -> dict:
return {"requires_y": True}
def _initialize_random(self) -> np.random.mtrand.RandomState:
if self.random_state is None:
self.random_state = random.randint(0, sys.maxint)
return np.random.mtrand._rand
return np.random.RandomState(self.random_state)
@ -63,7 +60,7 @@ class Odte(BaseEnsemble, ClassifierMixin):
return np.ones((n_samples,), dtype=np.float64)
return sample_weight.copy()
def _validate_estimator(self):
def _validate_estimator(self) -> None:
"""Check the estimator and set the base_estimator_ attribute."""
super()._validate_estimator(
default=Stree(random_state=self.random_state)
@ -71,7 +68,7 @@ class Odte(BaseEnsemble, ClassifierMixin):
def fit(
self, X: np.array, y: np.array, sample_weight: np.array = None
) -> "Odte":
) -> Odte:
# Check parameters are Ok.
if self.n_estimators < 3:
raise ValueError(
@ -79,34 +76,36 @@ class Odte(BaseEnsemble, ClassifierMixin):
{self.n_estimators})"
)
check_classification_targets(y)
X, y = check_X_y(X, y)
X, y = self._validate_data(X, y)
sample_weight = _check_sample_weight(
sample_weight, X, dtype=np.float64
)
check_classification_targets(y)
# Initialize computed parameters
# Build the estimator
self.n_features_in_ = X.shape[1]
self.n_features_ = X.shape[1]
self.max_features_ = self._initialize_max_features()
# build base_estimator_
self._validate_estimator()
self.classes_, y = np.unique(y, return_inverse=True)
self.n_classes_ = self.classes_.shape[0]
self.estimators_ = []
self.subspaces_ = []
self.n_classes_: int = self.classes_.shape[0]
self.estimators_: List[BaseEstimator] = []
self.subspaces_: List[Tuple[int, ...]] = []
self._train(X, y, sample_weight)
return self
def _train(
self, X: np.array, y: np.array, sample_weight: np.array
) -> "Odte":
) -> None:
random_box = self._initialize_random()
random_seed = self.random_state
n_samples = X.shape[0]
weights = self._initialize_sample_weight(sample_weight, n_samples)
boot_samples = self._get_bootstrap_n_samples(n_samples)
for _ in range(self.n_estimators):
# Build clf
clf = clone(self.base_estimator_)
clf.random_state = random_seed
random_seed += 1
self.estimators_.append(clf)
# bootstrap
indices = random_box.randint(0, n_samples, boot_samples)
@ -121,7 +120,7 @@ class Odte(BaseEnsemble, ClassifierMixin):
bootstrap[:, features], y[indices], current_weights[indices]
)
def _get_bootstrap_n_samples(self, n_samples) -> int:
def _get_bootstrap_n_samples(self, n_samples: int) -> int:
if self.max_samples is None:
return n_samples
if isinstance(self.max_samples, int):
@ -144,11 +143,11 @@ class Odte(BaseEnsemble, ClassifierMixin):
def _initialize_max_features(self) -> int:
if isinstance(self.max_features, str):
if self.max_features == "auto":
max_features = max(1, int(np.sqrt(self.n_features_)))
max_features = max(1, int(np.sqrt(self.n_features_in_)))
elif self.max_features == "sqrt":
max_features = max(1, int(np.sqrt(self.n_features_)))
max_features = max(1, int(np.sqrt(self.n_features_in_)))
elif self.max_features == "log2":
max_features = max(1, int(np.log2(self.n_features_)))
max_features = max(1, int(np.log2(self.n_features_in_)))
else:
raise ValueError(
"Invalid value for max_features. "
@ -156,13 +155,13 @@ class Odte(BaseEnsemble, ClassifierMixin):
"'sqrt' or 'log2'."
)
elif self.max_features is None:
max_features = self.n_features_
max_features = self.n_features_in_
elif isinstance(self.max_features, int):
max_features = abs(self.max_features)
else: # float
if self.max_features > 0.0:
max_features = max(
1, int(self.max_features * self.n_features_)
1, int(self.max_features * self.n_features_in_)
)
else:
raise ValueError(
@ -174,7 +173,7 @@ class Odte(BaseEnsemble, ClassifierMixin):
def _get_random_subspace(
self, dataset: np.array, labels: np.array
) -> np.array:
) -> Tuple[int, ...]:
features = range(dataset.shape[1])
features_sets = list(combinations(features, self.max_features_))
if len(features_sets) > 1:
@ -185,35 +184,16 @@ class Odte(BaseEnsemble, ClassifierMixin):
def predict(self, X: np.array) -> np.array:
proba = self.predict_proba(X)
return self.classes_.take((np.argmax(proba, axis=1)), axis=0)
return self.classes_[np.argmax(proba, axis=1)]
def predict_proba(self, X: np.array) -> np.array:
check_is_fitted(self, ["estimators_"])
check_is_fitted(self, "estimators_")
# Input validation
X = check_array(X)
if self.n_features_ != X.shape[1]:
raise ValueError(
"Number of features of the model must "
"match the input. Model n_features is {0} and "
"input n_features is {1}."
"".format(self.n_features_, X.shape[1])
)
X = self._validate_data(X, reset=False)
n_samples = X.shape[0]
result = np.zeros((n_samples, self.n_classes_))
for tree, features in zip(self.estimators_, self.subspaces_):
n_samples = X.shape[0]
result = np.zeros((n_samples, self.n_classes_))
predictions = tree.predict(X[:, features])
for i in range(n_samples):
result[i, predictions[i]] += 1
return result
def score(
self, X: np.array, y: np.array, sample_weight: np.array = None
) -> float:
check_classification_targets(y)
X, y = check_X_y(X, y)
y_pred = self.predict(X).reshape(y.shape)
# Compute accuracy for each possible representation
_, y_true, y_pred = _check_targets(y, y_pred)
check_consistent_length(y_true, y_pred, sample_weight)
score = y_true == y_pred
return _weighted_sum(score, sample_weight, normalize=True)
return result / self.n_estimators

View File

@ -39,13 +39,13 @@ class Odte_test(unittest.TestCase):
def test_initialize_max_feature(self):
expected_values = [
[0, 4, 10, 11],
[0, 2, 3, 5, 14, 15],
[0, 5, 6, 15],
[0, 2, 3, 9, 11, 14],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
[0, 4, 10, 11],
[0, 4, 10, 11],
[0, 4, 10, 11],
[0, 5, 6, 15],
[0, 5, 6, 15],
[0, 5, 6, 15],
]
X, y = load_dataset(
random_state=self._random_state, n_features=16, n_samples=10
@ -91,7 +91,7 @@ class Odte_test(unittest.TestCase):
warnings.filterwarnings("ignore", category=ConvergenceWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
X, y = [[1, 2], [5, 6], [9, 10], [16, 17]], [0, 1, 1, 2]
expected = [1, 1, 1, 1]
expected = [0, 1, 1, 1]
tclf = Odte(random_state=self._random_state, n_estimators=10,)
tclf.set_params(
**dict(
@ -116,7 +116,7 @@ class Odte_test(unittest.TestCase):
def test_score(self):
X, y = load_dataset(self._random_state)
expected = 0.948
expected = 0.9526666666666667
tclf = Odte(
random_state=self._random_state,
max_features=None,
@ -128,10 +128,10 @@ class Odte_test(unittest.TestCase):
def test_score_splitter_max_features(self):
X, y = load_dataset(self._random_state, n_features=12, n_samples=150)
results = [
0.6466666666666666,
0.6466666666666666,
0.9866666666666667,
0.9866666666666667,
1.0,
1.0,
0.9933333333333333,
0.9933333333333333,
]
for max_features in ["auto", None]:
for splitter in ["best", "random"]: