diff --git a/README.md b/README.md index 713d726..af50d56 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,15 @@ # STree + Oblique Tree classifier based on SVM nodes ## Example +## Jupyter + +[Binder link to test.ipynb](https://gesis.mybinder.org/binder/v2/gh/Doctorado-ML/STree/80b5cf8e722a0fc648624fbab69cd686847f906d) + +## Command line + ```python python main.py ``` diff --git a/main.py b/main.py index ad7598d..4f28db5 100644 --- a/main.py +++ b/main.py @@ -35,7 +35,7 @@ def load_creditcard(n_examples=0): return X, y #X, y = load_creditcard(-5000) #X, y = load_creditcard() -X, y = load_creditcard() +#X, y = load_creditcard() clf = Stree(C=.01, max_iter=100, random_state=random_state) clf.fit(X, y) @@ -46,3 +46,4 @@ yp = clf.predict_proba(X[0, :].reshape(-1, X.shape[1])) print(f"Predicting {y[0]} we have {yp[0, 0]} with {yp[0, 1]} of belief") print(f"Classifier's accuracy: {clf.score(X, y, print_out=False):.4f}") clf.show_tree(only_leaves=True) +print(clf.predict_proba(X)) diff --git a/requirements.txt b/requirements.txt index a243d29..89d20fa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ numpy==1.18.2 -scikit-learn==0.22.2 \ No newline at end of file +scikit-learn==0.22.2 +pandas==1.0.3 \ No newline at end of file diff --git a/test.ipynb b/test.ipynb index beee72b..abb8b15 100644 --- a/test.ipynb +++ b/test.ipynb @@ -20,6 +20,25 @@ "execution_count": 2, "metadata": {}, "outputs": [], + "source": [ + "import os, urllib.request\n", + "file_name = 'data/creditcard.csv'\n", + "if not os.path.isfile(file_name):\n", + " url = 'https://datahub.io/machine-learning/creditcard/r/creditcard.csv'\n", + " urllib.request.urlretrieve(url, file_name)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": "*Original Fraud: 0.173% 492\n*Original Valid: 99.827% 284315\nX.shape (284807, 28) y.shape (284807, 1)\n-Generated Fraud: 0.173% 492\n-Generated Valid: 99.827% 284315\n" + } + ], "source": [ "def load_creditcard(n_examples=0):\n", " df = pd.read_csv('data/creditcard.csv')\n", @@ -62,13 +81,13 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { "output_type": "stream", "name": "stdout", - "text": "root\nroot - Down\nLeaf class=0 belief=0.999638 counts=(array([0, 1]), array([284242, 103]))\nroot - Down - Up\nroot - Down - Up - Down\nLeaf class=0 belief=0.857143 counts=(array([0, 1]), array([18, 3]))\nLeaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Down - Up - Up\nLeaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nLeaf class=1 belief=0.862069 counts=(array([0, 1]), array([ 16, 100]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down\nLeaf class=0 belief=0.920000 counts=(array([0, 1]), array([23, 2]))\nLeaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nLeaf class=1 belief=1.000000 counts=(array([1]), array([3]))\nLeaf class=1 belief=0.948980 counts=(array([0, 1]), array([ 15, 279]))\n\n60.9873 secs\n" + "text": "root\nroot - Down\nroot - Down - Down, - Leaf class=0 belief=0.999638 counts=(array([0, 1]), array([284242, 103]))\nroot - Down - Up\nroot - Down - Up - Down\nroot - Down - Up - Down - Down, - Leaf class=0 belief=0.857143 counts=(array([0, 1]), array([18, 3]))\nroot - Down - Up - Down - Up, - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Down - Up - Up\nroot - Down - Up - Up - Down, - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Down - Up - Up - Up, - Leaf class=1 belief=0.862069 counts=(array([0, 1]), array([ 16, 100]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down\nroot - Up - Down - Down - Down, - Leaf class=0 belief=0.920000 counts=(array([0, 1]), array([23, 2]))\nroot - Up - Down - Down - Up, - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Down - Up, - Leaf class=1 belief=1.000000 counts=(array([1]), array([3]))\nroot - Up - Up, - Leaf class=1 belief=0.948980 counts=(array([0, 1]), array([ 15, 279]))\n\n41.5053 secs\n" } ], "source": [ @@ -81,13 +100,13 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [ { "output_type": "stream", "name": "stdout", - "text": "Accuracy: 0.999512\n0.3226 secs\n" + "text": "Accuracy: 0.999512\n0.2389 secs\n" } ], "source": [ @@ -98,13 +117,86 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 15, "metadata": {}, "outputs": [ { "output_type": "stream", "name": "stdout", - "text": "(284807, 2)\n0.4148 secs\n" + "text": "root\nroot - Down\nroot - Down - Down, - Leaf class=0 belief=0.999638 counts=(array([0, 1]), array([284242, 103]))\nroot - Down - Up\nroot - Down - Up - Down\nroot - Down - Up - Down - Down, - Leaf class=0 belief=0.857143 counts=(array([0, 1]), array([18, 3]))\nroot - Down - Up - Down - Up, - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Down - Up - Up\nroot - Down - Up - Up - Down, - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Down - Up - Up - Up, - Leaf class=1 belief=0.862069 counts=(array([0, 1]), array([ 16, 100]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down\nroot - Up - Down - Down - Down, - Leaf class=0 belief=0.920000 counts=(array([0, 1]), array([23, 2]))\nroot - Up - Down - Down - Up, - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Down - Up, - Leaf class=1 belief=1.000000 counts=(array([1]), array([3]))\nroot - Up - Up, - Leaf class=1 belief=0.948980 counts=(array([0, 1]), array([ 15, 279]))\n\n" + } + ], + "source": [ + "print(clf)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": "root - Down - Up - Down - Up, \n" + }, + { + "output_type": "error", + "ename": "AttributeError", + "evalue": "'NoneType' object has no attribute 'coef_'", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0mprintvec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_down\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mprintvec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_up\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0mprintvec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m\u001b[0m in \u001b[0;36mprintvec\u001b[0;34m(node)\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_leaf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mreturn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0mprintvec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_down\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 8\u001b[0m \u001b[0mprintvec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_up\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mprintvec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36mprintvec\u001b[0;34m(node)\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mreturn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0mprintvec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_down\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0mprintvec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_up\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m \u001b[0mprintvec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36mprintvec\u001b[0;34m(node)\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_leaf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mreturn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0mprintvec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_down\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 8\u001b[0m \u001b[0mprintvec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_up\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mprintvec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36mprintvec\u001b[0;34m(node)\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mreturn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0mprintvec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_down\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0mprintvec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_up\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m \u001b[0mprintvec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36mprintvec\u001b[0;34m(node)\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_vector\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_title\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_clf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcoef_\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_leaf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mreturn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mAttributeError\u001b[0m: 'NoneType' object has no attribute 'coef_'" + ] + } + ], + "source": [ + "def printvec(node):\n", + " if node._vector is None:\n", + " print(node._title)\n", + " print(node._clf.coef_)\n", + " if node.is_leaf():\n", + " return\n", + " printvec(node.get_down())\n", + " printvec(node.get_up())\n", + "printvec(clf._tree)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": "[[ 0.01272453 -0.00869691 0.00261326 0.013119 0.00200886 -0.00177167\n -0.0059729 -0.01137639 -0.01553846 -0.02868033 0.01536246 -0.02983645\n -0.00347147 -0.03159505 -0.0020723 -0.02189612 -0.03363292 -0.01214814\n 0.00413005 -0.00207565 0.00679726 0.00138744 -0.00159652 0.00035469\n -0.00397646 -0.00277537 0.01026267 0.00918578]] root - Down - Down, [-1.02881772]\n[[ 0.04405501 0.06724601 0.07316834 0.01356261 -0.01851765 0.00286983\n 0.01749505 0.02031383 0.00613145 0.00963547 -0.02055478 0.01969194\n -0.03824876 0.02122459 0.07457148 0.00967706 0.04688633 -0.00361551\n 0.00927978 0.01916087 -0.01800233 -0.00642752 0.01972691 -0.00539432\n -0.02530949 -0.00125887 -0.00248638 -0.01560347]] root - Down - Up - Down - Down, [-0.01578117]\nNone root - Down - Up - Down - Up, 0.0\n" + }, + { + "output_type": "error", + "ename": "TypeError", + "evalue": "'NoneType' object is not subscriptable", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0myp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict_proba\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0myp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"{time.time() - t:.4f} secs\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Code/STree/trees/Stree.py\u001b[0m in \u001b[0;36mpredict_proba\u001b[0;34m(self, X)\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpredict_proba\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 139\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__proba\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 140\u001b[0;31m \u001b[0mresult\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_predict_values\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 141\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__proba\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Code/STree/trees/Stree.py\u001b[0m in \u001b[0;36m_predict_values\u001b[0;34m(self, X)\u001b[0m\n\u001b[1;32m 124\u001b[0m \u001b[0;31m# setup prediction & make it happen\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 125\u001b[0m \u001b[0mindices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 126\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mpredict_class\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 127\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 128\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_reorder_results\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Code/STree/trees/Stree.py\u001b[0m in \u001b[0;36mpredict_class\u001b[0;34m(xp, indices, node)\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mprediction\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi_u\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0md\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi_d\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_split_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 117\u001b[0;31m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ml\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpredict_class\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi_d\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_down\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 118\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpredict_class\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi_u\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_up\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ml\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Code/STree/trees/Stree.py\u001b[0m in \u001b[0;36mpredict_class\u001b[0;34m(xp, indices, node)\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi_u\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0md\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi_d\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_split_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 117\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ml\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpredict_class\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi_d\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_down\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 118\u001b[0;31m \u001b[0mm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpredict_class\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi_u\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_up\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 119\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ml\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[0;31m# sklearn check\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Code/STree/trees/Stree.py\u001b[0m in \u001b[0;36mpredict_class\u001b[0;34m(xp, indices, node)\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mprediction\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi_u\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0md\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi_d\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_split_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 117\u001b[0;31m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ml\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpredict_class\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi_d\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_down\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 118\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpredict_class\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi_u\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_up\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ml\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Code/STree/trees/Stree.py\u001b[0m in \u001b[0;36mpredict_class\u001b[0;34m(xp, indices, node)\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi_u\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0md\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi_d\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_split_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 117\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ml\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpredict_class\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi_d\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_down\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 118\u001b[0;31m \u001b[0mm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpredict_class\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi_u\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_up\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 119\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ml\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[0;31m# sklearn check\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Code/STree/trees/Stree.py\u001b[0m in \u001b[0;36mpredict_class\u001b[0;34m(xp, indices, node)\u001b[0m\n\u001b[1;32m 110\u001b[0m \u001b[0;31m#prediction_proba = np.full((xp.shape[0], 1), node._belief)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 111\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_vector\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_title\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_interceptor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 112\u001b[0;31m \u001b[0mprediction_proba\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_linear_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mxp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 113\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprediction\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprediction_proba\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Code/STree/trees/Stree.py\u001b[0m in \u001b[0;36m_linear_function\u001b[0;34m(self, data, node)\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_linear_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mSnode\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 47\u001b[0;31m \u001b[0mcoef\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_vector\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 48\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcoef\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_interceptor\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mTypeError\u001b[0m: 'NoneType' object is not subscriptable" + ] } ], "source": [ @@ -125,13 +217,13 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [ { "output_type": "stream", "name": "stdout", - "text": "0.9991397683343457\n20.9481 secs\n" + "text": "0.9991397683343457\n13.6326 secs\n" } ], "source": [ @@ -144,13 +236,13 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [ { "output_type": "stream", "name": "stdout", - "text": "1.0\n32.2779 secs\n" + "text": "1.0\n18.8308 secs\n" } ], "source": [ @@ -169,6 +261,62 @@ "clf = Stree()\n", "check_estimator(clf)" ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": "array([0.36514837, 0.73029674, 1.09544512, 1.46059349])" + }, + "metadata": {}, + "execution_count": 9 + } + ], + "source": [ + "import numpy as np\n", + "a = np.array([-1, 2, 3, -4])\n", + "(np.abs(a) - a.mean()) / a.std()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": "[-1.05990681 -1.05181137 -0.99985686 -1.04574888] [[0.74267274 0.25732726]\n [0.74112258 0.25887742]\n [0.73103044 0.26896956]\n [0.73995773 0.26004227]]\n" + }, + { + "output_type": "error", + "ename": "NameError", + "evalue": "name 'yp' is not defined", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclf2\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_predict_proba_lr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0myp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mNameError\u001b[0m: name 'yp' is not defined" + ] + } + ], + "source": [ + "a = clf2.decision_function(X)\n", + "b = clf2._predict_proba_lr(X)\n", + "print(a[:4], b[:4])\n", + "print(yp[:4])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -186,8 +334,8 @@ }, "orig_nbformat": 2, "kernelspec": { - "name": "python37664bitstreevenva9e4a4efdc1042b6b577bd15fbe145ee", - "display_name": "Python 3.7.6 64-bit ('stree': venv)" + "name": "python37664bitgeneralvenvfbd0a23e74cf4e778460f5ffc6761f39", + "display_name": "Python 3.7.6 64-bit ('general': venv)" } }, "nbformat": 4, diff --git a/tests/Snode_test.py b/tests/Snode_test.py index 43609bd..9b40325 100644 --- a/tests/Snode_test.py +++ b/tests/Snode_test.py @@ -54,5 +54,20 @@ class Snode_test(unittest.TestCase): # Check Class class_computed = classes[card == max_card] self.assertEqual(class_computed, node._class) - check_leave(self._clf._tree) + + def test_nodes_coefs(self): + """Check if the nodes of the tree have the right attributes filled + """ + def run_tree(node: Snode): + if node._belief < 1: + # only exclude pure leaves + self.assertIsNotNone(node._clf) + self.assertIsNotNone(node._clf.coef_) + self.assertIsNotNone(node._vector) + self.assertIsNotNone(node._interceptor) + if node.is_leaf(): + return + run_tree(node.get_down()) + run_tree(node.get_up()) + run_tree(self._clf._tree) diff --git a/trees/Stree.py b/trees/Stree.py index 75d34ca..081b714 100644 --- a/trees/Stree.py +++ b/trees/Stree.py @@ -43,6 +43,10 @@ class Stree(BaseEstimator, ClassifierMixin): setattr(self, parameter, value) return self + def _linear_function(self, data: np.array, node: Snode) -> np.array: + coef = node._vector[0, :].reshape(-1, data.shape[1]) + return data.dot(coef.T) + node._interceptor[0] + def _split_data(self, node: Snode, data: np.ndarray, indices: np.ndarray) -> list: if self.__use_predictions: yp = node._clf.predict(data) @@ -50,8 +54,7 @@ class Stree(BaseEstimator, ClassifierMixin): else: # doesn't work with multiclass as each sample has to do inner product with its own coeficients # computes positition of every sample is w.r.t. the hyperplane - coef = node._vector[0, :].reshape(-1, data.shape[1]) - res = data.dot(coef.T) + node._interceptor[0] + res = self._linear_function(data, node) down = res > 0 up = ~down data_down = data[down[:, 0]] if any(down) else None @@ -105,6 +108,7 @@ class Stree(BaseEstimator, ClassifierMixin): prediction = np.full((xp.shape[0], 1), node._class) if self.__proba: prediction_proba = np.full((xp.shape[0], 1), node._belief) + #prediction_proba = self._linear_function(xp, node) return np.append(prediction, prediction_proba, axis=1), indices else: return prediction, indices @@ -134,7 +138,10 @@ class Stree(BaseEstimator, ClassifierMixin): self.__proba = True result, indices = self._predict_values(X) self.__proba = False - return self._reorder_results(result.reshape(X.shape[0], 2), indices) + result = result.reshape(X.shape[0], 2) + # Sigmoidize distance like in sklearn based on Platt(1999) + #result[:, 1] = 1 / (1 + np.exp(-result[:, 1])) + return self._reorder_results(result, indices) def score(self, X: np.array, y: np.array, print_out=True) -> float: if not self.__trained: