From b17582e93a70a8cbcfa51d5c23136d66a461d837 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Mon, 6 Jul 2020 00:12:46 +0200 Subject: [PATCH] Fix predict and predict_proba Add static types Fix tests --- notebooks/benchmark.ipynb | 67 ++++++++++++++----------- notebooks/wine_iris.ipynb | 32 ++++++++++-- odte/Odte.py | 102 +++++++++++++++----------------------- odte/tests/Odte_tests.py | 22 ++++---- 4 files changed, 119 insertions(+), 104 deletions(-) diff --git a/notebooks/benchmark.ipynb b/notebooks/benchmark.ipynb index 9503e8d..0063a25 100644 --- a/notebooks/benchmark.ipynb +++ b/notebooks/benchmark.ipynb @@ -17,7 +17,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -30,7 +30,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -47,7 +47,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -66,13 +66,15 @@ }, { "cell_type": "code", - "execution_count": 22, - "metadata": {}, + "execution_count": 4, + "metadata": { + "tags": [] + }, "outputs": [ { "output_type": "stream", "name": "stdout", - "text": "2020-06-15 11:44:45\n" + "text": "2020-07-04 21:56:25\n" } ], "source": [ @@ -88,7 +90,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -100,8 +102,10 @@ }, { "cell_type": "code", - "execution_count": 24, - "metadata": {}, + "execution_count": 6, + "metadata": { + "tags": [] + }, "outputs": [ { "output_type": "stream", @@ -116,7 +120,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -128,8 +132,10 @@ }, { "cell_type": "code", - "execution_count": 26, - "metadata": {}, + "execution_count": 8, + "metadata": { + "tags": [] + }, "outputs": [ { "output_type": "stream", @@ -153,7 +159,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -164,7 +170,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -174,7 +180,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -184,7 +190,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -194,7 +200,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -204,7 +210,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -214,12 +220,12 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# Oblique Decision Tree Ensemble\n", - "odte = Odte(random_state=random_state, n_estimators=10, max_features=\"auto\")" + "odte = Odte(random_state=random_state, max_features=\"auto\")" ] }, { @@ -231,7 +237,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ @@ -256,20 +262,22 @@ }, { "cell_type": "code", - "execution_count": 35, - "metadata": {}, + "execution_count": 17, + "metadata": { + "tags": [] + }, "outputs": [ { "output_type": "stream", "name": "stdout", - "text": "************************** Linear Tree 
**********************\nTrain Model Linear Tree took: 14.78 seconds\n=========== Linear Tree - Train 199,364 samples =============\n precision recall f1-score support\n\n 0 1.000000 1.000000 1.000000 199020\n 1 1.000000 1.000000 1.000000 344\n\n accuracy 1.000000 199364\n macro avg 1.000000 1.000000 1.000000 199364\nweighted avg 1.000000 1.000000 1.000000 199364\n\n=========== Linear Tree - Test 85,443 samples =============\n precision recall f1-score support\n\n 0 0.999578 0.999613 0.999596 85295\n 1 0.772414 0.756757 0.764505 148\n\n accuracy 0.999192 85443\n macro avg 0.885996 0.878185 0.882050 85443\nweighted avg 0.999184 0.999192 0.999188 85443\n\nConfusion Matrix in Train\n[[199020 0]\n [ 0 344]]\nConfusion Matrix in Test\n[[85262 33]\n [ 36 112]]\n************************** Random Forest **********************\nTrain Model Random Forest took: 163.9 seconds\n=========== Random Forest - Train 199,364 samples =============\n precision recall f1-score support\n\n 0 1.000000 1.000000 1.000000 199020\n 1 1.000000 1.000000 1.000000 344\n\n accuracy 1.000000 199364\n macro avg 1.000000 1.000000 1.000000 199364\nweighted avg 1.000000 1.000000 1.000000 199364\n\n=========== Random Forest - Test 85,443 samples =============\n precision recall f1-score support\n\n 0 0.999660 0.999965 0.999812 85295\n 1 0.975410 0.804054 0.881481 148\n\n accuracy 0.999625 85443\n macro avg 0.987535 0.902009 0.940647 85443\nweighted avg 0.999618 0.999625 0.999607 85443\n\nConfusion Matrix in Train\n[[199020 0]\n [ 0 344]]\nConfusion Matrix in Test\n[[85292 3]\n [ 29 119]]\n************************** Stree (SVM Tree) **********************\nTrain Model Stree (SVM Tree) took: 34.57 seconds\n=========== Stree (SVM Tree) - Train 199,364 samples =============\n precision recall f1-score support\n\n 0 0.999623 0.999864 0.999744 199020\n 1 0.908784 0.781977 0.840625 344\n\n accuracy 0.999488 199364\n macro avg 0.954204 0.890921 0.920184 199364\nweighted avg 0.999467 0.999488 0.999469 199364\n\n=========== Stree (SVM Tree) - Test 85,443 samples =============\n precision recall f1-score support\n\n 0 0.999637 0.999918 0.999777 85295\n 1 0.943548 0.790541 0.860294 148\n\n accuracy 0.999555 85443\n macro avg 0.971593 0.895229 0.930036 85443\nweighted avg 0.999540 0.999555 0.999536 85443\n\nConfusion Matrix in Train\n[[198993 27]\n [ 75 269]]\nConfusion Matrix in Test\n[[85288 7]\n [ 31 117]]\n************************** AdaBoost model **********************\nTrain Model AdaBoost model took: 44.36 seconds\n=========== AdaBoost model - Train 199,364 samples =============\n precision recall f1-score support\n\n 0 0.999392 0.999678 0.999535 199020\n 1 0.777003 0.648256 0.706815 344\n\n accuracy 0.999072 199364\n macro avg 0.888198 0.823967 0.853175 199364\nweighted avg 0.999008 0.999072 0.999030 199364\n\n=========== AdaBoost model - Test 85,443 samples =============\n precision recall f1-score support\n\n 0 0.999484 0.999707 0.999596 85295\n 1 0.806202 0.702703 0.750903 148\n\n accuracy 0.999192 85443\n macro avg 0.902843 0.851205 0.875249 85443\nweighted avg 0.999149 0.999192 0.999165 85443\n\nConfusion Matrix in Train\n[[198956 64]\n [ 121 223]]\nConfusion Matrix in Test\n[[85270 25]\n [ 44 104]]\n************************** Odte **********************\nTrain Model Odte took: 2.134e+03 seconds\n=========== Odte - Train 199,364 samples =============\n precision recall f1-score support\n\n 0 0.999583 1.000000 0.999792 199020\n 1 1.000000 0.758721 0.862810 344\n\n accuracy 0.999584 199364\n macro avg 0.999792 0.879360 0.931301 
199364\nweighted avg 0.999584 0.999584 0.999555 199364\n\n=========== Odte - Test 85,443 samples =============\n precision recall f1-score support\n\n 0 0.999543 0.999965 0.999754 85295\n 1 0.973214 0.736486 0.838462 148\n\n accuracy 0.999508 85443\n macro avg 0.986379 0.868226 0.919108 85443\nweighted avg 0.999497 0.999508 0.999474 85443\n\nConfusion Matrix in Train\n[[199020 0]\n [ 83 261]]\nConfusion Matrix in Test\n[[85292 3]\n [ 39 109]]\n" + "text": "************************** Linear Tree **********************\nTrain Model Linear Tree took: 14.81 seconds\n=========== Linear Tree - Train 199,364 samples =============\n precision recall f1-score support\n\n 0 1.000000 1.000000 1.000000 199020\n 1 1.000000 1.000000 1.000000 344\n\n accuracy 1.000000 199364\n macro avg 1.000000 1.000000 1.000000 199364\nweighted avg 1.000000 1.000000 1.000000 199364\n\n=========== Linear Tree - Test 85,443 samples =============\n precision recall f1-score support\n\n 0 0.999578 0.999613 0.999596 85295\n 1 0.772414 0.756757 0.764505 148\n\n accuracy 0.999192 85443\n macro avg 0.885996 0.878185 0.882050 85443\nweighted avg 0.999184 0.999192 0.999188 85443\n\nConfusion Matrix in Train\n[[199020 0]\n [ 0 344]]\nConfusion Matrix in Test\n[[85262 33]\n [ 36 112]]\n************************** Random Forest **********************\nTrain Model Random Forest took: 172.6 seconds\n=========== Random Forest - Train 199,364 samples =============\n precision recall f1-score support\n\n 0 1.000000 1.000000 1.000000 199020\n 1 1.000000 1.000000 1.000000 344\n\n accuracy 1.000000 199364\n macro avg 1.000000 1.000000 1.000000 199364\nweighted avg 1.000000 1.000000 1.000000 199364\n\n=========== Random Forest - Test 85,443 samples =============\n precision recall f1-score support\n\n 0 0.999660 0.999965 0.999812 85295\n 1 0.975410 0.804054 0.881481 148\n\n accuracy 0.999625 85443\n macro avg 0.987535 0.902009 0.940647 85443\nweighted avg 0.999618 0.999625 0.999607 85443\n\nConfusion Matrix in Train\n[[199020 0]\n [ 0 344]]\nConfusion Matrix in Test\n[[85292 3]\n [ 29 119]]\n************************** Stree (SVM Tree) **********************\nTrain Model Stree (SVM Tree) took: 39.26 seconds\n=========== Stree (SVM Tree) - Train 199,364 samples =============\n precision recall f1-score support\n\n 0 0.999623 0.999864 0.999744 199020\n 1 0.908784 0.781977 0.840625 344\n\n accuracy 0.999488 199364\n macro avg 0.954204 0.890921 0.920184 199364\nweighted avg 0.999467 0.999488 0.999469 199364\n\n=========== Stree (SVM Tree) - Test 85,443 samples =============\n precision recall f1-score support\n\n 0 0.999637 0.999918 0.999777 85295\n 1 0.943548 0.790541 0.860294 148\n\n accuracy 0.999555 85443\n macro avg 0.971593 0.895229 0.930036 85443\nweighted avg 0.999540 0.999555 0.999536 85443\n\nConfusion Matrix in Train\n[[198993 27]\n [ 75 269]]\nConfusion Matrix in Test\n[[85288 7]\n [ 31 117]]\n************************** AdaBoost model **********************\nTrain Model AdaBoost model took: 49.55 seconds\n=========== AdaBoost model - Train 199,364 samples =============\n precision recall f1-score support\n\n 0 0.999392 0.999678 0.999535 199020\n 1 0.777003 0.648256 0.706815 344\n\n accuracy 0.999072 199364\n macro avg 0.888198 0.823967 0.853175 199364\nweighted avg 0.999008 0.999072 0.999030 199364\n\n=========== AdaBoost model - Test 85,443 samples =============\n precision recall f1-score support\n\n 0 0.999484 0.999707 0.999596 85295\n 1 0.806202 0.702703 0.750903 148\n\n accuracy 0.999192 85443\n macro avg 0.902843 0.851205 
0.875249 85443\nweighted avg 0.999149 0.999192 0.999165 85443\n\nConfusion Matrix in Train\n[[198956 64]\n [ 121 223]]\nConfusion Matrix in Test\n[[85270 25]\n [ 44 104]]\n************************** Odte model **********************\nTrain Model Odte model took: 5.758e+03 seconds\n=========== Odte model - Train 199,364 samples =============\n precision recall f1-score support\n\n 0 0.998725 0.999990 0.999357 199020\n 1 0.978261 0.261628 0.412844 344\n\n accuracy 0.998716 199364\n macro avg 0.988493 0.630809 0.706101 199364\nweighted avg 0.998690 0.998716 0.998345 199364\n\n=========== Odte model - Test 85,443 samples =============\n precision recall f1-score support\n\n 0 0.998794 0.999988 0.999391 85295\n 1 0.978261 0.304054 0.463918 148\n\n accuracy 0.998783 85443\n macro avg 0.988527 0.652021 0.731654 85443\nweighted avg 0.998758 0.998783 0.998463 85443\n\nConfusion Matrix in Train\n[[199018 2]\n [ 254 90]]\nConfusion Matrix in Test\n[[85294 1]\n [ 103 45]]\n" } ], "source": [ "# Train & Test models\n", "models = {\n", " 'Linear Tree':linear_tree, 'Random Forest': random_forest, 'Stree (SVM Tree)': stree, \n", - " 'AdaBoost model': adaboost, 'Odte': odte #'Gradient Boost.': gradient\n", + " 'AdaBoost model': adaboost, 'Odte model': odte #'Gradient Boost.': gradient\n", "}\n", "\n", "best_f1 = 0\n", @@ -285,13 +293,15 @@ }, { "cell_type": "code", - "execution_count": 36, - "metadata": {}, + "execution_count": 18, + "metadata": { + "tags": [] + }, "outputs": [ { "output_type": "stream", "name": "stdout", - "text": "**************************************************************************************************************\n*The best f1 model is Random Forest, with a f1 score: 0.8815 in 163.896 seconds with 0.7 samples in train dataset\n**************************************************************************************************************\nModel: Linear Tree\t Time: 14.78 seconds\t f1: 0.7645\nModel: Random Forest\t Time: 163.90 seconds\t f1: 0.8815\nModel: Stree (SVM Tree)\t Time: 34.57 seconds\t f1: 0.8603\nModel: AdaBoost model\t Time: 44.36 seconds\t f1: 0.7509\nModel: Odte\t Time: 2134.25 seconds\t f1: 0.8385\n" + "text": "**************************************************************************************************************\n*The best f1 model is Random Forest, with a f1 score: 0.8815 in 172.611 seconds with 0.7 samples in train dataset\n**************************************************************************************************************\nModel: Linear Tree\t Time: 14.81 seconds\t f1: 0.7645\nModel: Random Forest\t Time: 172.61 seconds\t f1: 0.8815\nModel: Stree (SVM Tree)\t Time: 39.26 seconds\t f1: 0.8603\nModel: AdaBoost model\t Time: 49.55 seconds\t f1: 0.7509\nModel: Odte model\t Time: 5758.26 seconds\t f1: 0.4639\n" } ], "source": [ @@ -330,6 +340,7 @@ "Model: AdaBoost model\t Time: 73.83 seconds\t f1: 0.7509\n", "Model: Gradient Boost.\t Time: 388.69 seconds\t f1: 0.5259\n", "Model: Neural Network\t Time: 25.47 seconds\t f1: 0.8328\n", + "Model: Odte \t Time:2134.25 seconds\t f1: 0.8385\n", "```" ] } diff --git a/notebooks/wine_iris.ipynb b/notebooks/wine_iris.ipynb index ec1d342..7736895 100644 --- a/notebooks/wine_iris.ipynb +++ b/notebooks/wine_iris.ipynb @@ -55,7 +55,7 @@ { "output_type": "stream", "name": "stdout", - "text": "****************************** Results for wine ******************************\nTraining stree...\nScore: 94.444 in 0.17 seconds\nTraining odte...\nScore: 97.222 in 2.70 seconds\nTraining adaboost...\nScore: 94.444 in 0.60 
seconds\nTraining bagging...\nScore: 100.000 in 2.55 seconds\n" + "text": "****************************** Results for wine ******************************\nTraining stree...\nScore: 94.444 in 0.19 seconds\nTraining odte...\nScore: 100.000 in 3.43 seconds\nTraining adaboost...\nScore: 94.444 in 0.76 seconds\nTraining bagging...\nScore: 100.000 in 3.27 seconds\n" } ], "source": [ @@ -102,7 +102,7 @@ { "output_type": "stream", "name": "stdout", - "text": "****************************** Results for iris ******************************\nTraining stree...\nScore: 100.000 in 0.02 seconds\nTraining odte...\nScore: 93.333 in 0.12 seconds\nTraining adaboost...\nScore: 83.333 in 0.01 seconds\nTraining bagging...\nScore: 100.000 in 0.11 seconds\n" + "text": "****************************** Results for iris ******************************\nTraining stree...\nScore: 100.000 in 0.02 seconds\nTraining odte...\nScore: 100.000 in 0.15 seconds\nTraining adaboost...\nScore: 83.333 in 0.01 seconds\nTraining bagging...\nScore: 96.667 in 0.13 seconds\n" } ], "source": [ @@ -124,7 +124,7 @@ { "output_type": "stream", "name": "stdout", - "text": "{'fit_time': array([0.15752316, 0.18354201, 0.14742589, 0.13827896, 0.14534211]), 'score_time': array([0.00940681, 0.01064587, 0.01085019, 0.00925183, 0.00878191]), 'test_score': array([0.8 , 0.93333333, 0.93333333, 0.93333333, 0.96666667]), 'train_score': array([0.875 , 0.95 , 0.98333333, 0.98333333, 0.95833333])}\n91.333 +- 0.058\n" + "text": "{'fit_time': array([0.23599219, 0.22772503, 0.21689606, 0.20017815, 0.22257805]), 'score_time': array([0.01378369, 0.01322389, 0.0125649 , 0.01751685, 0.01062703]), 'test_score': array([1. , 1. , 1. , 0.93333333, 1. ]), 'train_score': array([0.98333333, 0.96666667, 0.99166667, 0.99166667, 0.975 ])}\n98.667 +- 0.027\n" } ], "source": [ @@ -143,7 +143,7 @@ { "output_type": "stream", "name": "stdout", - "text": "{'fit_time': array([0.01752877, 0.03304005, 0.03542018, 0.03398919, 0.03945518]), 'score_time': array([0.00135112, 0.00164104, 0.00159597, 0.0018959 , 0.00189495]), 'test_score': array([1. , 0.93333333, 0.93333333, 0.93333333, 0.96666667]), 'train_score': array([0.93333333, 0.96666667, 0.96666667, 0.96666667, 0.95 ])}\n95.333 +- 0.027\n" + "text": "{'fit_time': array([0.02912688, 0.05858397, 0.06724691, 0.02860498, 0.03802919]), 'score_time': array([0.0024271 , 0.0022819 , 0.00219584, 0.00195408, 0.00342584]), 'test_score': array([1. 
, 0.93333333, 0.93333333, 0.93333333, 0.96666667]), 'train_score': array([0.93333333, 0.96666667, 0.96666667, 0.96666667, 0.95 ])}\n95.333 +- 0.027\n"
 }
 ],
 "source": [
@@ -151,6 +151,30 @@
 "print(cross)\n",
 "print(f\"{np.mean(cross['test_score'])*100:.3f} +- {np.std(cross['test_score']):.3f}\")"
 ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": "1 functools.partial(<function ...>, 'Odte')\n2 functools.partial(<function ...>, 'Odte')\n3 functools.partial(<function ...>, 'Odte')\n4 functools.partial(<function ...>, 'Odte')\n5 functools.partial(<function ...>, 'Odte')\n6 functools.partial(<function ...>, 'Odte')\n7 functools.partial(<function ...>, 'Odte')\n8 functools.partial(<function ...>, 'Odte')\n9 functools.partial(<function ...>, 'Odte')\n10 functools.partial(<function ...>, 'Odte', readonly_memmap=True)\n11 functools.partial(<function ...>, 'Odte')\n12 functools.partial(<function ...>, 'Odte')\n13 functools.partial(<function ...>, 'Odte')\n14 functools.partial(<function ...>, 'Odte')\n15 functools.partial(<function ...>, 'Odte')\n16 functools.partial(<function ...>, 'Odte')\n17 functools.partial(<function ...>, 'Odte')\n18 functools.partial(<function ...>, 'Odte')\n19 functools.partial(<function ...>, 'Odte')\n20 functools.partial(<function ...>, 'Odte')\n21 functools.partial(<function ...>, 'Odte')\n22 functools.partial(<function ...>, 'Odte')\n23 functools.partial(<function ...>, 'Odte')\n24 functools.partial(<function ...>, 'Odte', readonly_memmap=True)\n25 functools.partial(<function ...>, 'Odte', readonly_memmap=True, X_dtype='float32')\n26 functools.partial(<function ...>, 'Odte')\n27 functools.partial(<function ...>, 'Odte')\n28 functools.partial(<function ...>, 'Odte')\n29 functools.partial(<function ...>, 'Odte')\n30 functools.partial(<function ...>, 'Odte')\n31 functools.partial(<function ...>, 'Odte')\n32 functools.partial(<function ...>, 'Odte')\n33 functools.partial(<function ...>, 'Odte')\n34 functools.partial(<function ...>, 'Odte')\n35 functools.partial(<function ...>, 'Odte')\n36 functools.partial(<function ...>, 'Odte')\n37 functools.partial(<function ...>, 'Odte')\n38 functools.partial(<function ...>, 'Odte')\n39 functools.partial(<function ...>, 'Odte')\n40 functools.partial(<function ...>, 'Odte')\n41 functools.partial(<function ...>, 'Odte')\n42 functools.partial(<function ...>, 'Odte')\n"
+ }
+ ],
+ "source": [
+ "from sklearn.utils.estimator_checks import check_estimator\n",
+ "# Make checks one by one\n",
+ "c = 0\n",
+ "checks = check_estimator(Odte(), generate_only=True)\n",
+ "for check in checks:\n",
+ " c += 1\n",
+ " print(c, check[1])\n",
+ " check[1](check[0])"
+ ]
 }
 ],
 "metadata": {
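Note: the new notebook cell above walks scikit-learn's estimator checks one at a time instead of calling check_estimator in one shot, so a failure pinpoints the exact check that broke. (The printed function reprs were lost to formatting and are marked <function ...> above.) A standalone sketch of the same loop, assuming scikit-learn 0.23, where check_estimator accepts generate_only=True and yields (estimator, check) pairs:

from sklearn.utils.estimator_checks import check_estimator

from odte import Odte

# generate_only=True returns the (estimator, check) pairs without running
# them, so each check can be executed and reported individually.
for number, (estimator, check) in enumerate(
    check_estimator(Odte(), generate_only=True), start=1
):
    print(number, check)  # check is a functools.partial bound to 'Odte'
    check(estimator)  # raises if Odte violates the estimator contract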
diff --git a/odte/Odte.py b/odte/Odte.py
index 538aeec..464d626 100644
--- a/odte/Odte.py
+++ b/odte/Odte.py
@@ -5,33 +5,32 @@ __license__ = "MIT"
 __version__ = "0.1"
 Build a forest of oblique trees based on STree
 """
-
+from __future__ import annotations
 import random
-from typing import Union
+import sys
+from typing import Union, Optional, Tuple, List
 from itertools import combinations
-import numpy as np
-from sklearn.utils import check_consistent_length
-from sklearn.metrics._classification import _weighted_sum, _check_targets
-from sklearn.utils.multiclass import check_classification_targets
-from sklearn.base import clone, ClassifierMixin
-from sklearn.ensemble import BaseEnsemble
-from sklearn.utils.validation import (
-    check_X_y,
-    check_array,
+import numpy as np  # type: ignore
+from sklearn.utils.multiclass import (  # type: ignore
+    check_classification_targets,
+)
+from sklearn.base import clone, BaseEstimator, ClassifierMixin  # type: ignore
+from sklearn.ensemble import BaseEnsemble  # type: ignore
+from sklearn.utils.validation import (  # type: ignore
     check_is_fitted,
     _check_sample_weight,
 )
-from stree import Stree
+from stree import Stree  # type: ignore
 
 
-class Odte(BaseEnsemble, ClassifierMixin):
+class Odte(BaseEnsemble, ClassifierMixin):  # type: ignore
     def __init__(
         self,
-        base_estimator=None,
-        random_state: int = None,
-        max_features: Union[str, int, float] = 1.0,
-        max_samples: Union[int, float] = None,
+        base_estimator: BaseEstimator = None,
+        random_state: int = 0,
+        max_features: Optional[Union[str, int, float]] = None,
+        max_samples: Optional[Union[int, float]] = None,
         n_estimators: int = 100,
     ):
         base_estimator = (
@@ -47,11 +46,9 @@ class Odte(BaseEnsemble, ClassifierMixin):
         self.max_features = max_features
         self.max_samples = max_samples  # size of bootstrap
 
-    def _more_tags(self) -> dict:
-        return {"requires_y": True}
-
     def _initialize_random(self) -> np.random.mtrand.RandomState:
         if self.random_state is None:
+            self.random_state = random.randint(0, sys.maxsize)
             return np.random.mtrand._rand
         return np.random.RandomState(self.random_state)
 
@@ -63,7 +60,7 @@ class Odte(BaseEnsemble, ClassifierMixin):
             return np.ones((n_samples,), dtype=np.float64)
         return sample_weight.copy()
 
-    def _validate_estimator(self):
+    def _validate_estimator(self) -> None:
         """Check the estimator and set the base_estimator_ attribute."""
         super()._validate_estimator(
             default=Stree(random_state=self.random_state)
@@ -71,7 +68,7 @@ class Odte(BaseEnsemble, ClassifierMixin):
     def fit(
         self, X: np.array, y: np.array, sample_weight: np.array = None
-    ) -> "Odte":
+    ) -> Odte:
         # Check parameters are Ok.
         if self.n_estimators < 3:
             raise ValueError(
@@ -79,34 +76,36 @@ class Odte(BaseEnsemble, ClassifierMixin):
                 {self.n_estimators})"
             )
         check_classification_targets(y)
-        X, y = check_X_y(X, y)
+        X, y = self._validate_data(X, y)
         sample_weight = _check_sample_weight(
             sample_weight, X, dtype=np.float64
         )
         check_classification_targets(y)
         # Initialize computed parameters
         # Build the estimator
-        self.n_features_in_ = X.shape[1]
-        self.n_features_ = X.shape[1]
         self.max_features_ = self._initialize_max_features()
+        # build base_estimator_
        self._validate_estimator()
         self.classes_, y = np.unique(y, return_inverse=True)
-        self.n_classes_ = self.classes_.shape[0]
-        self.estimators_ = []
-        self.subspaces_ = []
+        self.n_classes_: int = self.classes_.shape[0]
+        self.estimators_: List[BaseEstimator] = []
+        self.subspaces_: List[Tuple[int, ...]] = []
         self._train(X, y, sample_weight)
         return self
 
     def _train(
         self, X: np.array, y: np.array, sample_weight: np.array
-    ) -> "Odte":
+    ) -> None:
         random_box = self._initialize_random()
+        random_seed = self.random_state
         n_samples = X.shape[0]
         weights = self._initialize_sample_weight(sample_weight, n_samples)
         boot_samples = self._get_bootstrap_n_samples(n_samples)
         for _ in range(self.n_estimators):
             # Build clf
             clf = clone(self.base_estimator_)
+            clf.random_state = random_seed
+            random_seed += 1
             self.estimators_.append(clf)
             # bootstrap
             indices = random_box.randint(0, n_samples, boot_samples)
@@ -121,7 +120,7 @@ class Odte(BaseEnsemble, ClassifierMixin):
                 bootstrap[:, features], y[indices], current_weights[indices]
             )
 
-    def _get_bootstrap_n_samples(self, n_samples) -> int:
+    def _get_bootstrap_n_samples(self, n_samples: int) -> int:
         if self.max_samples is None:
             return n_samples
         if isinstance(self.max_samples, int):
@@ -144,11 +143,11 @@ class Odte(BaseEnsemble, ClassifierMixin):
     def _initialize_max_features(self) -> int:
         if isinstance(self.max_features, str):
             if self.max_features == "auto":
-                max_features = max(1, int(np.sqrt(self.n_features_)))
+                max_features = max(1, int(np.sqrt(self.n_features_in_)))
             elif self.max_features == "sqrt":
-                max_features = max(1, int(np.sqrt(self.n_features_)))
+                max_features = max(1, int(np.sqrt(self.n_features_in_)))
             elif self.max_features == "log2":
-                max_features = max(1, int(np.log2(self.n_features_)))
+                max_features = max(1, int(np.log2(self.n_features_in_)))
             else:
                 raise ValueError(
                     "Invalid value for max_features. "
@@ -156,13 +155,13 @@ class Odte(BaseEnsemble, ClassifierMixin):
                     "'sqrt' or 'log2'."
                 )
         elif self.max_features is None:
-            max_features = self.n_features_
+            max_features = self.n_features_in_
         elif isinstance(self.max_features, int):
             max_features = abs(self.max_features)
         else:  # float
             if self.max_features > 0.0:
                 max_features = max(
-                    1, int(self.max_features * self.n_features_)
+                    1, int(self.max_features * self.n_features_in_)
                 )
             else:
                 raise ValueError(
@@ -174,7 +173,7 @@ class Odte(BaseEnsemble, ClassifierMixin):
     def _get_random_subspace(
         self, dataset: np.array, labels: np.array
-    ) -> np.array:
+    ) -> Tuple[int, ...]:
         features = range(dataset.shape[1])
         features_sets = list(combinations(features, self.max_features_))
         if len(features_sets) > 1:
@@ -185,35 +184,16 @@ class Odte(BaseEnsemble, ClassifierMixin):
 
     def predict(self, X: np.array) -> np.array:
         proba = self.predict_proba(X)
-        return self.classes_.take((np.argmax(proba, axis=1)), axis=0)
+        return self.classes_[np.argmax(proba, axis=1)]
 
     def predict_proba(self, X: np.array) -> np.array:
-        check_is_fitted(self, ["estimators_"])
+        check_is_fitted(self, "estimators_")
         # Input validation
-        X = check_array(X)
-        if self.n_features_ != X.shape[1]:
-            raise ValueError(
-                "Number of features of the model must "
-                "match the input. Model n_features is {0} and "
-                "input n_features is {1}."
-                "".format(self.n_features_, X.shape[1])
-            )
+        X = self._validate_data(X, reset=False)
+        n_samples = X.shape[0]
+        result = np.zeros((n_samples, self.n_classes_))
         for tree, features in zip(self.estimators_, self.subspaces_):
-            n_samples = X.shape[0]
-            result = np.zeros((n_samples, self.n_classes_))
             predictions = tree.predict(X[:, features])
             for i in range(n_samples):
                 result[i, predictions[i]] += 1
-        return result
-
-    def score(
-        self, X: np.array, y: np.array, sample_weight: np.array = None
-    ) -> float:
-        check_classification_targets(y)
-        X, y = check_X_y(X, y)
-        y_pred = self.predict(X).reshape(y.shape)
-        # Compute accuracy for each possible representation
-        _, y_true, y_pred = _check_targets(y, y_pred)
-        check_consistent_length(y_true, y_pred, sample_weight)
-        score = y_true == y_pred
-        return _weighted_sum(score, sample_weight, normalize=True)
+        return result / self.n_estimators
diff --git a/odte/tests/Odte_tests.py b/odte/tests/Odte_tests.py
index 210a457..28b7c70 100644
--- a/odte/tests/Odte_tests.py
+++ b/odte/tests/Odte_tests.py
@@ -39,13 +39,13 @@ class Odte_test(unittest.TestCase):
 
     def test_initialize_max_feature(self):
         expected_values = [
-            [0, 4, 10, 11],
-            [0, 2, 3, 5, 14, 15],
+            [0, 5, 6, 15],
+            [0, 2, 3, 9, 11, 14],
             [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
             [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
-            [0, 4, 10, 11],
-            [0, 4, 10, 11],
-            [0, 4, 10, 11],
+            [0, 5, 6, 15],
+            [0, 5, 6, 15],
+            [0, 5, 6, 15],
         ]
         X, y = load_dataset(
             random_state=self._random_state, n_features=16, n_samples=10
@@ -91,7 +91,7 @@ class Odte_test(unittest.TestCase):
         warnings.filterwarnings("ignore", category=ConvergenceWarning)
         warnings.filterwarnings("ignore", category=RuntimeWarning)
         X, y = [[1, 2], [5, 6], [9, 10], [16, 17]], [0, 1, 1, 2]
-        expected = [1, 1, 1, 1]
+        expected = [0, 1, 1, 1]
         tclf = Odte(random_state=self._random_state, n_estimators=10,)
         tclf.set_params(
             **dict(
@@ -116,7 +116,7 @@ class Odte_test(unittest.TestCase):
 
     def test_score(self):
         X, y = load_dataset(self._random_state)
-        expected = 0.948
+        expected = 0.9526666666666667
         tclf = Odte(
             random_state=self._random_state,
             max_features=None,
@@ -128,10 +128,10 @@ class Odte_test(unittest.TestCase):
     def test_score_splitter_max_features(self):
         X, y = load_dataset(self._random_state, n_features=12, n_samples=150)
         results = [
-            0.6466666666666666,
-            0.6466666666666666,
-            0.9866666666666667,
-            0.9866666666666667,
+            1.0,
+            1.0,
+            0.9933333333333333,
+            0.9933333333333333,
         ]
         for max_features in ["auto", None]:
             for splitter in ["best", "random"]:
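The predict_proba change is the fix that names this patch: the previous version re-created the result matrix inside the loop over estimators and so returned the raw vote counts of the last tree only, while the new version accumulates one vote per tree and divides by n_estimators, making each row a distribution whose argmax, mapped through classes_, is the prediction. A minimal sketch of the post-fix behavior (illustrative only; assumes odte and its stree dependency are installed):

import numpy as np
from sklearn.datasets import load_iris

from odte import Odte

X, y = load_iris(return_X_y=True)
clf = Odte(random_state=0, n_estimators=10).fit(X, y)

proba = clf.predict_proba(X)
# each row is now the fraction of the 10 trees voting for each class
assert np.allclose(proba.sum(axis=1), 1.0)
# predict() is the argmax of the averaged votes, mapped to class labels
assert (clf.predict(X) == clf.classes_[proba.argmax(axis=1)]).all()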
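For context on _train: each estimator is fitted on its own bootstrap (row indices drawn with replacement from the shared RandomState) restricted to one randomly chosen combination of max_features_ columns, and the new consecutive per-clone seeds keep otherwise identical trees from producing identical splits. A small self-contained illustration of that sampling scheme (the names below are local to the example, not the module's API):

from itertools import combinations

import numpy as np

rng = np.random.RandomState(0)
X = rng.rand(20, 6)  # 20 samples, 6 features
n_samples, n_features = X.shape
max_features = 2

# bootstrap: draw row indices with replacement, as _train does
indices = rng.randint(0, n_samples, n_samples)
bootstrap = X[indices, :]

# random subspace: pick one of the C(6, 2) = 15 column combinations,
# mirroring _get_random_subspace
features_sets = list(combinations(range(n_features), max_features))
features = features_sets[rng.randint(0, len(features_sets))]

print(features, bootstrap[:, features].shape)  # a 2-column bootstrapped view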