From 502ee72799c98dfe7b8588f2f3fb98305b74ab34 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Sun, 14 Jun 2020 14:00:21 +0200 Subject: [PATCH] #2 Add predict and score support Add a test in features notebook Show max_features in main.py --- main.py | 9 +++++++ notebooks/features.ipynb | 55 ++++++++++++++++++++++++++++++++------- stree/Strees.py | 18 +++++++++---- stree/tests/Stree_test.py | 14 ++++++++++ 4 files changed, 82 insertions(+), 14 deletions(-) diff --git a/main.py b/main.py index e4722c7..7b40929 100644 --- a/main.py +++ b/main.py @@ -12,6 +12,15 @@ Xtrain, Xtest, ytrain, ytest = train_test_split( ) now = time.time() +print("Predicting with max_features=sqrt(n_features)") +clf = Stree(C=0.01, random_state=random_state, max_features="auto") +clf.fit(Xtrain, ytrain) +print(f"Took {time.time() - now:.2f} seconds to train") +print(clf) +print(f"Classifier's accuracy (train): {clf.score(Xtrain, ytrain):.4f}") +print(f"Classifier's accuracy (test) : {clf.score(Xtest, ytest):.4f}") +print("=" * 40) +print("Predicting with max_features=n_features") clf = Stree(C=0.01, random_state=random_state) clf.fit(Xtrain, ytrain) print(f"Took {time.time() - now:.2f} seconds to train") diff --git a/notebooks/features.ipynb b/notebooks/features.ipynb index 9eda9b0..c7d0611 100644 --- a/notebooks/features.ipynb +++ b/notebooks/features.ipynb @@ -64,7 +64,7 @@ { "output_type": "stream", "name": "stdout", - "text": "Fraud: 0.173% 492\nValid: 99.827% 284315\nX.shape (1492, 28) y.shape (1492,)\nFraud: 33.110% 494\nValid: 66.890% 998\n" + "text": "Fraud: 0.173% 492\nValid: 99.827% 284315\nX.shape (1492, 28) y.shape (1492,)\nFraud: 33.244% 496\nValid: 66.756% 996\n" } ], "source": [ @@ -135,7 +135,7 @@ { "output_type": "stream", "name": "stdout", - "text": "Accuracy of Train without weights 0.9789272030651341\nAccuracy of Train with weights 0.9952107279693486\nAccuracy of Tests without weights 0.9598214285714286\nAccuracy of Tests with weights 0.9508928571428571\n" + "text": "Accuracy of Train without weights 0.9808429118773946\nAccuracy of Train with weights 0.9904214559386973\nAccuracy of Tests without weights 0.9441964285714286\nAccuracy of Tests with weights 0.9375\n" } ], "source": [ @@ -162,7 +162,7 @@ { "output_type": "stream", "name": "stdout", - "text": "Time: 0.27s\tKernel: linear\tAccuracy_train: 0.9683908045977011\tAccuracy_test: 0.953125\nTime: 0.09s\tKernel: rbf\tAccuracy_train: 0.9875478927203065\tAccuracy_test: 0.9598214285714286\nTime: 0.06s\tKernel: poly\tAccuracy_train: 0.9885057471264368\tAccuracy_test: 0.9464285714285714\n" + "text": "Time: 0.13s\tKernel: linear\tAccuracy_train: 0.9693486590038314\tAccuracy_test: 0.9598214285714286\nTime: 0.09s\tKernel: rbf\tAccuracy_train: 0.9923371647509579\tAccuracy_test: 0.953125\nTime: 0.09s\tKernel: poly\tAccuracy_train: 0.9913793103448276\tAccuracy_test: 0.9375\n" } ], "source": [ @@ -195,7 +195,7 @@ { "output_type": "stream", "name": "stdout", - "text": "************** C=0.001 ****************************\nClassifier's accuracy (train): 0.9531\nClassifier's accuracy (test) : 0.9621\nroot\nroot - Down, - Leaf class=1 belief= 0.983713 counts=(array([0, 1]), array([ 5, 302]))\nroot - Up, - Leaf class=0 belief= 0.940299 counts=(array([0, 1]), array([693, 44]))\n\n**************************************************\n************** C=0.01 ****************************\nClassifier's accuracy (train): 0.9569\nClassifier's accuracy (test) : 0.9621\nroot\nroot - Down, - Leaf class=1 belief= 0.990228 counts=(array([0, 1]), array([ 3, 304]))\nroot - Up, - Leaf class=0 belief= 0.943012 counts=(array([0, 1]), array([695, 42]))\n\n**************************************************\n************** C=1 ****************************\nClassifier's accuracy (train): 0.9655\nClassifier's accuracy (test) : 0.9643\nroot\nroot - Down\nroot - Down - Down, - Leaf class=1 belief= 1.000000 counts=(array([1]), array([310]))\nroot - Down - Up, - Leaf class=0 belief= 1.000000 counts=(array([0]), array([5]))\nroot - Up, - Leaf class=0 belief= 0.950617 counts=(array([0, 1]), array([693, 36]))\n\n**************************************************\n************** C=5 ****************************\nClassifier's accuracy (train): 0.9684\nClassifier's accuracy (test) : 0.9598\nroot\nroot - Down\nroot - Down - Down, - Leaf class=1 belief= 1.000000 counts=(array([1]), array([311]))\nroot - Down - Up, - Leaf class=0 belief= 1.000000 counts=(array([0]), array([8]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, - Leaf class=1 belief= 1.000000 counts=(array([1]), array([1]))\nroot - Up - Down - Up, - Leaf class=0 belief= 1.000000 counts=(array([0]), array([2]))\nroot - Up - Up\nroot - Up - Up - Down, - Leaf class=0 belief= 1.000000 counts=(array([0]), array([2]))\nroot - Up - Up - Up\nroot - Up - Up - Up - Down\nroot - Up - Up - Up - Down - Down, - Leaf class=1 belief= 1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Up - Down - Up, - Leaf class=0 belief= 1.000000 counts=(array([0]), array([1]))\nroot - Up - Up - Up - Up, - Leaf class=0 belief= 0.954039 counts=(array([0, 1]), array([685, 33]))\n\n**************************************************\n************** C=17 ****************************\nClassifier's accuracy (train): 0.9751\nClassifier's accuracy (test) : 0.9464\nroot\nroot - Down\nroot - Down - Down, - Leaf class=1 belief= 1.000000 counts=(array([1]), array([304]))\nroot - Down - Up, - Leaf class=0 belief= 1.000000 counts=(array([0]), array([8]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, - Leaf class=1 belief= 1.000000 counts=(array([1]), array([4]))\nroot - Up - Down - Up, - Leaf class=0 belief= 1.000000 counts=(array([0]), array([3]))\nroot - Up - Up\nroot - Up - Up - Down\nroot - Up - Up - Down - Down, - Leaf class=1 belief= 1.000000 counts=(array([1]), array([4]))\nroot - Up - Up - Down - Up, - Leaf class=0 belief= 1.000000 counts=(array([0]), array([2]))\nroot - Up - Up - Up\nroot - Up - Up - Up - Down\nroot - Up - Up - Up - Down - Down, - Leaf class=1 belief= 1.000000 counts=(array([1]), array([3]))\nroot - Up - Up - Up - Down - Up, - Leaf class=0 belief= 1.000000 counts=(array([0]), array([1]))\nroot - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Down - Down, - Leaf class=1 belief= 1.000000 counts=(array([1]), array([3]))\nroot - Up - Up - Up - Up - Down - Up, - Leaf class=0 belief= 1.000000 counts=(array([0]), array([3]))\nroot - Up - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Up - Down, - Leaf class=1 belief= 1.000000 counts=(array([1]), array([2]))\nroot - Up - Up - Up - Up - Up - Up, - Leaf class=0 belief= 0.963225 counts=(array([0, 1]), array([681, 26]))\n\n**************************************************\n0.6869 secs\n" + "text": "************** C=0.001 ****************************\nClassifier's accuracy (train): 0.9588\nClassifier's accuracy (test) : 0.9487\nroot feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.4438\nroot - Down feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.0374\nroot - Down - Down, - Leaf class=1 belief= 0.984076 impurity=0.0313 counts=(array([0, 1]), array([ 5, 309]))\nroot - Down - Up, - Leaf class=0 belief= 1.000000 impurity=0.0000 counts=(array([0]), array([1]))\nroot - Up, - Leaf class=0 belief= 0.947874 impurity=0.0988 counts=(array([0, 1]), array([691, 38]))\n\n**************************************************\n************** C=0.01 ****************************\nClassifier's accuracy (train): 0.9588\nClassifier's accuracy (test) : 0.9531\nroot feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.4438\nroot - Down feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.0192\nroot - Down - Down, - Leaf class=1 belief= 0.993506 impurity=0.0129 counts=(array([0, 1]), array([ 2, 306]))\nroot - Down - Up, - Leaf class=0 belief= 1.000000 impurity=0.0000 counts=(array([0]), array([1]))\nroot - Up, - Leaf class=0 belief= 0.944218 impurity=0.1053 counts=(array([0, 1]), array([694, 41]))\n\n**************************************************\n************** C=1 ****************************\nClassifier's accuracy (train): 0.9665\nClassifier's accuracy (test) : 0.9643\nroot feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.4438\nroot - Down feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.0189\nroot - Down - Down, - Leaf class=1 belief= 1.000000 impurity=0.0000 counts=(array([1]), array([312]))\nroot - Down - Up, - Leaf class=0 belief= 1.000000 impurity=0.0000 counts=(array([0]), array([3]))\nroot - Up, - Leaf class=0 belief= 0.951989 impurity=0.0914 counts=(array([0, 1]), array([694, 35]))\n\n**************************************************\n************** C=5 ****************************\nClassifier's accuracy (train): 0.9665\nClassifier's accuracy (test) : 0.9621\nroot feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.4438\nroot - Down feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.0250\nroot - Down - Down, - Leaf class=1 belief= 1.000000 impurity=0.0000 counts=(array([1]), array([312]))\nroot - Down - Up, - Leaf class=0 belief= 1.000000 impurity=0.0000 counts=(array([0]), array([4]))\nroot - Up, - Leaf class=0 belief= 0.951923 impurity=0.0915 counts=(array([0, 1]), array([693, 35]))\n\n**************************************************\n************** C=17 ****************************\nClassifier's accuracy (train): 0.9703\nClassifier's accuracy (test) : 0.9665\nroot feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.4438\nroot - Down feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.0367\nroot - Down - Down, - Leaf class=1 belief= 1.000000 impurity=0.0000 counts=(array([1]), array([315]))\nroot - Down - Up, - Leaf class=0 belief= 1.000000 impurity=0.0000 counts=(array([0]), array([6]))\nroot - Up feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.0846\nroot - Up - Down, - Leaf class=1 belief= 1.000000 impurity=0.0000 counts=(array([1]), array([1]))\nroot - Up - Up, - Leaf class=0 belief= 0.957064 impurity=0.0822 counts=(array([0, 1]), array([691, 31]))\n\n**************************************************\n0.4375 secs\n" } ], "source": [ @@ -227,7 +227,7 @@ { "output_type": "stream", "name": "stdout", - "text": "root\nroot - Down\nroot - Down - Down, - Leaf class=1 belief= 1.000000 counts=(array([1]), array([304]))\nroot - Down - Up, - Leaf class=0 belief= 1.000000 counts=(array([0]), array([8]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, - Leaf class=1 belief= 1.000000 counts=(array([1]), array([4]))\nroot - Up - Down - Up, - Leaf class=0 belief= 1.000000 counts=(array([0]), array([3]))\nroot - Up - Up\nroot - Up - Up - Down\nroot - Up - Up - Down - Down, - Leaf class=1 belief= 1.000000 counts=(array([1]), array([4]))\nroot - Up - Up - Down - Up, - Leaf class=0 belief= 1.000000 counts=(array([0]), array([2]))\nroot - Up - Up - Up\nroot - Up - Up - Up - Down\nroot - Up - Up - Up - Down - Down, - Leaf class=1 belief= 1.000000 counts=(array([1]), array([3]))\nroot - Up - Up - Up - Down - Up, - Leaf class=0 belief= 1.000000 counts=(array([0]), array([1]))\nroot - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Down - Down, - Leaf class=1 belief= 1.000000 counts=(array([1]), array([3]))\nroot - Up - Up - Up - Up - Down - Up, - Leaf class=0 belief= 1.000000 counts=(array([0]), array([3]))\nroot - Up - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Up - Down, - Leaf class=1 belief= 1.000000 counts=(array([1]), array([2]))\nroot - Up - Up - Up - Up - Up - Up, - Leaf class=0 belief= 0.963225 counts=(array([0, 1]), array([681, 26]))\n" + "text": "root feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.4438\nroot - Down feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.0367\nroot - Down - Down, - Leaf class=1 belief= 1.000000 impurity=0.0000 counts=(array([1]), array([315]))\nroot - Down - Up, - Leaf class=0 belief= 1.000000 impurity=0.0000 counts=(array([0]), array([6]))\nroot - Up feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.0846\nroot - Up - Down, - Leaf class=1 belief= 1.000000 impurity=0.0000 counts=(array([1]), array([1]))\nroot - Up - Up, - Leaf class=0 belief= 0.957064 impurity=0.0822 counts=(array([0, 1]), array([691, 31]))\n" } ], "source": [ @@ -244,7 +244,7 @@ { "output_type": "stream", "name": "stdout", - "text": "root\nroot - Down\nroot - Down - Down, - Leaf class=1 belief= 1.000000 counts=(array([1]), array([304]))\nroot - Down - Up, - Leaf class=0 belief= 1.000000 counts=(array([0]), array([8]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, - Leaf class=1 belief= 1.000000 counts=(array([1]), array([4]))\nroot - Up - Down - Up, - Leaf class=0 belief= 1.000000 counts=(array([0]), array([3]))\nroot - Up - Up\nroot - Up - Up - Down\nroot - Up - Up - Down - Down, - Leaf class=1 belief= 1.000000 counts=(array([1]), array([4]))\nroot - Up - Up - Down - Up, - Leaf class=0 belief= 1.000000 counts=(array([0]), array([2]))\nroot - Up - Up - Up\nroot - Up - Up - Up - Down\nroot - Up - Up - Up - Down - Down, - Leaf class=1 belief= 1.000000 counts=(array([1]), array([3]))\nroot - Up - Up - Up - Down - Up, - Leaf class=0 belief= 1.000000 counts=(array([0]), array([1]))\nroot - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Down - Down, - Leaf class=1 belief= 1.000000 counts=(array([1]), array([3]))\nroot - Up - Up - Up - Up - Down - Up, - Leaf class=0 belief= 1.000000 counts=(array([0]), array([3]))\nroot - Up - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Up - Down, - Leaf class=1 belief= 1.000000 counts=(array([1]), array([2]))\nroot - Up - Up - Up - Up - Up - Up, - Leaf class=0 belief= 0.963225 counts=(array([0, 1]), array([681, 26]))\n" + "text": "root feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.4438\nroot - Down feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.0367\nroot - Down - Down, - Leaf class=1 belief= 1.000000 impurity=0.0000 counts=(array([1]), array([315]))\nroot - Down - Up, - Leaf class=0 belief= 1.000000 impurity=0.0000 counts=(array([0]), array([6]))\nroot - Up feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.0846\nroot - Up - Down, - Leaf class=1 belief= 1.000000 impurity=0.0000 counts=(array([1]), array([1]))\nroot - Up - Up, - Leaf class=0 belief= 0.957064 impurity=0.0822 counts=(array([0, 1]), array([691, 31]))\n" } ], "source": [ @@ -268,7 +268,7 @@ { "output_type": "stream", "name": "stdout", - "text": "1 functools.partial(, 'Stree')\n2 functools.partial(, 'Stree')\n3 functools.partial(, 'Stree')\n4 functools.partial(, 'Stree')\n5 functools.partial(, 'Stree')\n6 functools.partial(, 'Stree')\n7 functools.partial(, 'Stree')\n8 functools.partial(, 'Stree')\n9 functools.partial(, 'Stree')\n10 functools.partial(, 'Stree', readonly_memmap=True)\n11 functools.partial(, 'Stree')\n12 functools.partial(, 'Stree')\n13 functools.partial(, 'Stree')\n14 functools.partial(, 'Stree')\n15 functools.partial(, 'Stree')\n16 functools.partial(, 'Stree')\n17 functools.partial(, 'Stree')\n18 functools.partial(, 'Stree')\n19 functools.partial(, 'Stree')\n20 functools.partial(, 'Stree')\n21 functools.partial(, 'Stree')\n22 functools.partial(, 'Stree')\n23 functools.partial(, 'Stree')\n24 functools.partial(, 'Stree', readonly_memmap=True)\n25 functools.partial(, 'Stree', readonly_memmap=True, X_dtype='float32')\n26 functools.partial(, 'Stree')\n27 functools.partial(, 'Stree')\n28 functools.partial(, 'Stree')\n29 functools.partial(, 'Stree')\n30 functools.partial(, 'Stree')\n31 functools.partial(, 'Stree')\n32 functools.partial(, 'Stree')\n33 functools.partial(, 'Stree')\n34 functools.partial(, 'Stree')\n35 functools.partial(, 'Stree')\n36 functools.partial(, 'Stree')\n37 functools.partial(, 'Stree')\n38 functools.partial(, 'Stree')\n39 functools.partial(, 'Stree')\n40 functools.partial(, 'Stree')\n41 functools.partial(, 'Stree')\n42 functools.partial(, 'Stree')\n43 functools.partial(, 'Stree')\n" + "text": "1 functools.partial(, 'Stree')\n2 functools.partial(, 'Stree')\n3 functools.partial(, 'Stree')\n4 functools.partial(, 'Stree')\n5 functools.partial(, 'Stree')\n6 functools.partial(, 'Stree')\n7 functools.partial(, 'Stree')\n8 functools.partial(, 'Stree')\n9 functools.partial(, 'Stree')\n10 functools.partial(, 'Stree', readonly_memmap=True)\n11 functools.partial(, 'Stree')\n12 functools.partial(, 'Stree')\n13 functools.partial(, 'Stree')\n14 functools.partial(, 'Stree')\n15 functools.partial(, 'Stree')\n16 functools.partial(, 'Stree')\n17 functools.partial(, 'Stree')\n18 functools.partial(, 'Stree')\n19 functools.partial(, 'Stree')\n20 functools.partial(, 'Stree')\n21 functools.partial(, 'Stree')\n22 functools.partial(, 'Stree')\n23 functools.partial(, 'Stree')\n24 functools.partial(, 'Stree', readonly_memmap=True)\n25 functools.partial(, 'Stree', readonly_memmap=True, X_dtype='float32')\n26 functools.partial(, 'Stree')\n27 functools.partial(, 'Stree')\n28 functools.partial(, 'Stree')\n29 functools.partial(, 'Stree')\n30 functools.partial(, 'Stree')\n31 functools.partial(, 'Stree')\n32 functools.partial(, 'Stree')\n33 functools.partial(, 'Stree')\n34 functools.partial(, 'Stree')\n35 functools.partial(, 'Stree')\n36 functools.partial(, 'Stree')\n37 functools.partial(, 'Stree')\n38 functools.partial(, 'Stree')\n39 functools.partial(, 'Stree')\n40 functools.partial(, 'Stree')\n41 functools.partial(, 'Stree')\n42 functools.partial(, 'Stree')\n43 functools.partial(, 'Stree')\n" } ], "source": [ @@ -306,7 +306,7 @@ { "output_type": "stream", "name": "stdout", - "text": "== Not Weighted ===\nSVC train score ..: 0.9521072796934866\nSTree train score : 0.9578544061302682\nSVC test score ...: 0.9553571428571429\nSTree test score .: 0.9575892857142857\n==== Weighted =====\nSVC train score ..: 0.9616858237547893\nSTree train score : 0.9616858237547893\nSVC test score ...: 0.9642857142857143\nSTree test score .: 0.9598214285714286\n*SVC test score ..: 0.951413553411694\n*STree test score : 0.9480517444389333\n" + "text": "== Not Weighted ===\nSVC train score ..: 0.9578544061302682\nSTree train score : 0.960727969348659\nSVC test score ...: 0.9508928571428571\nSTree test score .: 0.9553571428571429\n==== Weighted =====\nSVC train score ..: 0.9636015325670498\nSTree train score : 0.9626436781609196\nSVC test score ...: 0.9553571428571429\nSTree test score .: 0.9553571428571429\n*SVC test score ..: 0.9447820728419238\n*STree test score : 0.9447820728419238\n" } ], "source": [ @@ -338,12 +338,49 @@ { "output_type": "stream", "name": "stdout", - "text": "root\nroot - Down\nroot - Down - Down, - Leaf class=1 belief= 0.969325 counts=(array([0, 1]), array([ 10, 316]))\nroot - Down - Up, - Leaf class=0 belief= 1.000000 counts=(array([0]), array([1]))\nroot - Up, - Leaf class=0 belief= 0.958159 counts=(array([0, 1]), array([687, 30]))\n\n" + "text": "root feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.4438\nroot - Down, - Leaf class=1 belief= 0.978261 impurity=0.0425 counts=(array([0, 1]), array([ 7, 315]))\nroot - Up, - Leaf class=0 belief= 0.955679 impurity=0.0847 counts=(array([0, 1]), array([690, 32]))\n\n" } ], "source": [ "print(clf)" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Test max_features" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": "****************************************\nmax_features None = 28\nTrain score : 0.9664750957854407\nTest score .: 0.9642857142857143\nTook 0.09 seconds\n****************************************\nmax_features auto = 5\nTrain score : 0.9511494252873564\nTest score .: 0.9441964285714286\nTook 0.37 seconds\n****************************************\nmax_features log2 = 4\nTrain score : 0.935823754789272\nTest score .: 0.9330357142857143\nTook 0.10 seconds\n****************************************\nmax_features 7 = 7\nTrain score : 0.9568965517241379\nTest score .: 0.9397321428571429\nTook 3.36 seconds\n****************************************\nmax_features 0.5 = 14\nTrain score : 0.960727969348659\nTest score .: 0.9486607142857143\nTook 112.42 seconds\n****************************************\nmax_features 0.1 = 2\nTrain score : 0.8793103448275862\nTest score .: 0.8839285714285714\nTook 0.06 seconds\n****************************************\nmax_features 0.7 = 19\nTrain score : 0.9655172413793104\nTest score .: 0.9553571428571429\nTook 10.59 seconds\n" + } + ], + "source": [ + "for max_features in [None, \"auto\", \"log2\", 7, .5, .1, .7]:\n", + " now = time.time()\n", + " print(\"*\"*40)\n", + " clf = Stree(random_state=random_state, max_features=max_features)\n", + " clf.fit(Xtrain, ytrain)\n", + " print(f\"max_features {max_features} = {clf.max_features_}\")\n", + " print(\"Train score :\", clf.score(Xtrain, ytrain))\n", + " print(\"Test score .:\", clf.score(Xtest, ytest))\n", + " print(f\"Took {time.time() - now:.2f} seconds\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/stree/Strees.py b/stree/Strees.py index cb8731f..1a37b4e 100644 --- a/stree/Strees.py +++ b/stree/Strees.py @@ -205,7 +205,7 @@ class Stree(BaseEstimator, ClassifierMixin): the hyperplane of the node :rtype: np.array """ - return node._clf.decision_function(data) + return node._clf.decision_function(data[:, node._features]) def _min_distance(self, data: np.array, _) -> np.array: # chooses the lowest distance of every sample @@ -286,11 +286,14 @@ class Stree(BaseEstimator, ClassifierMixin): sample_weight = _check_sample_weight(sample_weight, X) check_classification_targets(y) # Initialize computed parameters + if self.random_state is not None: + random.seed(self.random_state) self.classes_, y = np.unique(y, return_inverse=True) self.n_classes_ = self.classes_.shape[0] self.n_iter_ = self.max_iter self.depth_ = 0 self.n_features_ = X.shape[1] + self.n_features_in_ = X.shape[1] self.max_features_ = self._initialize_max_features() self.criterion_function_ = getattr(self, f"_{self.criterion}") self.tree_ = self.train(X, y, sample_weight, 1, "root") @@ -336,12 +339,12 @@ class Stree(BaseEstimator, ClassifierMixin): ) # Train the model clf = self._build_clf() - Xs, indices_subset = self._get_subspace(X) + Xs, features = self._get_subspace(X) clf.fit(Xs, y, sample_weight=sample_weight) impurity = self.criterion_function_(y) - node = Snode(clf, X, y, indices_subset, impurity, title) + node = Snode(clf, X, y, features, impurity, title) self.depth_ = max(depth, self.depth_) - down = self._split_criteria(self._distances(node, Xs), node) + down = self._split_criteria(self._distances(node, X), node) X_U, X_D = self._split_array(X, down) y_u, y_d = self._split_array(y, down) sw_u, sw_d = self._split_array(sample_weight, down) @@ -439,6 +442,11 @@ class Stree(BaseEstimator, ClassifierMixin): check_is_fitted(self, ["tree_"]) # Input validation X = check_array(X) + if X.shape[1] != self.n_features_: + raise ValueError( + f"Expected {self.n_features_} features but got " + f"({X.shape[1]})" + ) # setup prediction & make it happen indices = np.arange(X.shape[0]) result = ( @@ -548,7 +556,7 @@ class Stree(BaseEstimator, ClassifierMixin): features = range(dataset.shape[1]) features_sets = list(combinations(features, self.max_features_)) if len(features_sets) > 1: - return features_sets[random.randint(0, len(features_sets))] + return features_sets[random.randint(0, len(features_sets) - 1)] else: return features_sets[0] diff --git a/stree/tests/Stree_test.py b/stree/tests/Stree_test.py index a3fb3d1..371e1d0 100644 --- a/stree/tests/Stree_test.py +++ b/stree/tests/Stree_test.py @@ -360,3 +360,17 @@ class Stree_test(unittest.TestCase): clf = Stree(criterion="entropy") clf.fit(*load_dataset()) self.assertEqual(expected, clf.criterion_function_(y)) + + def test_predict_feature_dimensions(self): + X = np.random.rand(10, 5) + y = np.random.randint(0, 2, 10) + clf = Stree() + clf.fit(X, y) + with self.assertRaises(ValueError): + clf.predict(X[:, :3]) + + def test_score_max_features(self): + X, y = load_dataset(self._random_state) + clf = Stree(random_state=self._random_state, max_features=2) + clf.fit(X, y) + self.assertAlmostEqual(0.9426666666666667, clf.score(X, y))