From 6ebd0f9be30d5dba3143b6d4876cba55d2d8cdcc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Tue, 19 May 2020 18:19:23 +0200 Subject: [PATCH] integrate iterator in Stree --- main.py | 1 + test2.ipynb | 67 ++++++++++++++++------------------------------ trees/Siterator.py | 22 +++++++++++---- trees/Snode.py | 4 +-- trees/Stree.py | 33 +++++++---------------- 5 files changed, 53 insertions(+), 74 deletions(-) diff --git a/main.py b/main.py index d4f045a..ab77139 100644 --- a/main.py +++ b/main.py @@ -48,6 +48,7 @@ print(clf) print(f"Classifier's accuracy (train): {clf.score(Xtrain, ytrain):.4f}") print(f"Classifier's accuracy (test) : {clf.score(Xtest, ytest):.4f}") proba = clf.predict_proba(Xtest) +print("Checking that we have correct probabilities, these are probabilities of sample belonging to class 1") res0 = proba[proba[:, 0] == 0] res1 = proba[proba[:, 0] == 0] print("++++++++++res0++++++++++++") diff --git a/test2.ipynb b/test2.ipynb index 8350f64..dfa924f 100644 --- a/test2.ipynb +++ b/test2.ipynb @@ -35,7 +35,7 @@ { "output_type": "stream", "name": "stdout", - "text": "Fraud: 0.173% 492\nValid: 99.827% 284315\nX.shape (1492, 28) y.shape (1492,)\nFraud: 33.311% 497\nValid: 66.689% 995\n" + "text": "Fraud: 0.173% 492\nValid: 99.827% 284315\nX.shape (1492, 28) y.shape (1492,)\nFraud: 32.976% 492\nValid: 67.024% 1000\n" } ], "source": [ @@ -97,7 +97,7 @@ { "output_type": "stream", "name": "stdout", - "text": "************** C=0.001 ****************************\nClassifier's accuracy (train): 0.9626\nClassifier's accuracy (test) : 0.9487\nroot\nroot - Down\nroot - Down - Down, - Leaf class=1 belief=0.978261 counts=(array([0, 1]), array([ 7, 315]))\nroot - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up, - Leaf class=0 belief=0.955432 counts=(array([0, 1]), array([686, 32]))\n\n**************************************************\n************** C=0.01 ****************************\nClassifier's accuracy (train): 0.9636\nClassifier's accuracy (test) : 0.9509\nroot\nroot - Down\nroot - Down - Down, - Leaf class=1 belief=0.987421 counts=(array([0, 1]), array([ 4, 314]))\nroot - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up, - Leaf class=0 belief=0.953039 counts=(array([0, 1]), array([690, 34]))\n\n**************************************************\n************** C=1 ****************************\nClassifier's accuracy (train): 0.9703\nClassifier's accuracy (test) : 0.9531\nroot\nroot - Down\nroot - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([316]))\nroot - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([5]))\nroot - Up\nroot - Up - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up, - Leaf class=0 belief=0.957064 counts=(array([0, 1]), array([691, 31]))\n\n**************************************************\n************** C=5 ****************************\nClassifier's accuracy (train): 0.9684\nClassifier's accuracy (test) : 0.9554\nroot\nroot - Down\nroot - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([315]))\nroot - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([6]))\nroot - Up, - Leaf class=0 belief=0.954357 counts=(array([0, 1]), array([690, 33]))\n\n**************************************************\n************** C=17 ****************************\nClassifier's accuracy (train): 0.9761\nClassifier's accuracy (test) : 0.9464\nroot\nroot - Down\nroot - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([316]))\nroot - Down - Up\nroot - Down - Up - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Down - Up - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([9]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([2]))\nroot - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([4]))\nroot - Up - Up\nroot - Up - Up - Down\nroot - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([4]))\nroot - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([3]))\nroot - Up - Up - Up, - Leaf class=0 belief=0.964539 counts=(array([0, 1]), array([680, 25]))\n\n**************************************************\n0.4014 secs\n" + "text": "************** C=0.001 ****************************\nClassifier's accuracy (train): 0.9550\nClassifier's accuracy (test) : 0.9487\nroot\nroot - Down\nroot - Down - Down, - Leaf class=1 belief=0.977346 counts=(array([0, 1]), array([ 7, 302]))\nroot - Up\nroot - Up - Down, - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up\nroot - Up - Up - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([2]))\nroot - Up - Up - Up, - Leaf class=0 belief=0.945280 counts=(array([0, 1]), array([691, 40]))\n\n**************************************************\n************** C=0.01 ****************************\nClassifier's accuracy (train): 0.9569\nClassifier's accuracy (test) : 0.9576\nroot\nroot - Down, - Leaf class=1 belief=0.986971 counts=(array([0, 1]), array([ 4, 303]))\nroot - Up, - Leaf class=0 belief=0.944369 counts=(array([0, 1]), array([696, 41]))\n\n**************************************************\n************** C=1 ****************************\nClassifier's accuracy (train): 0.9674\nClassifier's accuracy (test) : 0.9554\nroot\nroot - Down\nroot - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([310]))\nroot - Up, - Leaf class=0 belief=0.953232 counts=(array([0, 1]), array([693, 34]))\nroot - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([7]))\n\n**************************************************\n************** C=5 ****************************\nClassifier's accuracy (train): 0.9693\nClassifier's accuracy (test) : 0.9487\nroot\nroot - Down\nroot - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([310]))\nroot - Up\nroot - Up - Down, - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([7]))\nroot - Up - Up\nroot - Up - Up - Down, - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up - Up\nroot - Up - Up - Up - Down, - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([2]))\nroot - Up - Up - Up - Up - Up, - Leaf class=0 belief=0.955494 counts=(array([0, 1]), array([687, 32]))\nroot - Up - Up - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\n\n**************************************************\n************** C=17 ****************************\nClassifier's accuracy (train): 0.9780\nClassifier's accuracy (test) : 0.9487\nroot\nroot - Down\nroot - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([301]))\nroot - Up\nroot - Up - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([2]))\nroot - Down - Up\nroot - Down - Up - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([15]))\nroot - Up - Up\nroot - Up - Up - Down\nroot - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([3]))\nroot - Down - Up - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([15]))\nroot - Up - Up - Up, - Leaf class=0 belief=0.967468 counts=(array([0, 1]), array([684, 23]))\nroot - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\n\n**************************************************\n0.7277 secs\n" } ], "source": [ @@ -132,27 +132,18 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [ { "output_type": "stream", "name": "stdout", - "text": "root\nroot - Down\nroot - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([316]))\nroot - Down - Up\nroot - Down - Up - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Down - Up - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([9]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([2]))\nroot - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([4]))\nroot - Up - Up\nroot - Up - Up - Down\nroot - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([4]))\nroot - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([3]))\nroot - Up - Up - Up, - Leaf class=0 belief=0.964539 counts=(array([0, 1]), array([680, 25]))\n\n" + "text": "root\nroot - Down\nroot - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([301]))\nroot - Up\nroot - Up - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([2]))\nroot - Down - Up\nroot - Down - Up - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([15]))\nroot - Up - Up\nroot - Up - Up - Down\nroot - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([3]))\nroot - Down - Up - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([15]))\nroot - Up - Up - Up, - Leaf class=0 belief=0.967468 counts=(array([0, 1]), array([684, 23]))\nroot - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\n" } ], "source": [ - "print(clf)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "from trees.Siterator import Siterator\n", - "it = Siterator(clf._tree)" + "for i in list(clf):\n", + " print(i)" ] }, { @@ -163,46 +154,34 @@ { "output_type": "stream", "name": "stdout", - "text": "root\n\nroot - Down\n\nroot - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([316]))\n\nroot - Up\n\nroot - Up - Down\n\nroot - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([2]))\n\nroot - Down - Up\n\nroot - Down - Up - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\n\nroot - Up - Up\n\nroot - Up - Up - Down\n\nroot - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([4]))\n\nroot - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([4]))\n\nroot - Down - Up - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([9]))\n\nroot - Up - Up - Up, - Leaf class=0 belief=0.964539 counts=(array([0, 1]), array([680, 25]))\n\nroot - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([3]))\n\n" + "text": "root\nroot - Down\nroot - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([301]))\nroot - Up\nroot - Up - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([2]))\nroot - Down - Up\nroot - Down - Up - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([15]))\nroot - Up - Up\nroot - Up - Up - Down\nroot - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([3]))\nroot - Down - Up - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([15]))\nroot - Up - Up - Up, - Leaf class=0 belief=0.967468 counts=(array([0, 1]), array([684, 23]))\nroot - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\n" } ], "source": [ - "while(it.hasNext()):\n", - " print(it.next())" + "for i in clf:\n", + " print(i)" ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 11, "metadata": {}, "outputs": [ { - "output_type": "error", - "ename": "ImportError", - "evalue": "Failed to import any qt binding", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mget_ipython\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_line_magic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'matplotlib'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'qt'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mmpl_toolkits\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmplot3d\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mAxes3D\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmatplotlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpyplot\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mmatplotlib\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mcm\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mmatplotlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mticker\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mLinearLocator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mFormatStrFormatter\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Code/pyblique/venv/lib/python3.7/site-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mrun_line_magic\u001b[0;34m(self, magic_name, line, _stack_depth)\u001b[0m\n\u001b[1;32m 2315\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'local_ns'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_getframe\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstack_depth\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf_locals\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2316\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuiltin_trap\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2317\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2318\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2319\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m\u001b[0m in \u001b[0;36mmatplotlib\u001b[0;34m(self, line)\u001b[0m\n", - "\u001b[0;32m~/Code/pyblique/venv/lib/python3.7/site-packages/IPython/core/magic.py\u001b[0m in \u001b[0;36m\u001b[0;34m(f, *a, **k)\u001b[0m\n\u001b[1;32m 185\u001b[0m \u001b[0;31m# but it's overkill for just that one bit of state.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 186\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mmagic_deco\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 187\u001b[0;31m \u001b[0mcall\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 188\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 189\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcallable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Code/pyblique/venv/lib/python3.7/site-packages/IPython/core/magics/pylab.py\u001b[0m in \u001b[0;36mmatplotlib\u001b[0;34m(self, line)\u001b[0m\n\u001b[1;32m 97\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Available matplotlib backends: %s\"\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mbackends_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 98\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 99\u001b[0;31m \u001b[0mgui\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshell\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menable_matplotlib\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgui\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlower\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgui\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgui\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 100\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_show_matplotlib_backend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgui\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Code/pyblique/venv/lib/python3.7/site-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36menable_matplotlib\u001b[0;34m(self, gui)\u001b[0m\n\u001b[1;32m 3417\u001b[0m \u001b[0mgui\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfind_gui_and_backend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpylab_gui_select\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3418\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3419\u001b[0;31m \u001b[0mpt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mactivate_matplotlib\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbackend\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3420\u001b[0m \u001b[0mpt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconfigure_inline_support\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3421\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Code/pyblique/venv/lib/python3.7/site-packages/IPython/core/pylabtools.py\u001b[0m in \u001b[0;36mactivate_matplotlib\u001b[0;34m(backend)\u001b[0m\n\u001b[1;32m 318\u001b[0m \u001b[0;31m# when this function runs.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 319\u001b[0m \u001b[0;31m# So avoid needing matplotlib attribute-lookup to access pyplot.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 320\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mmatplotlib\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mpyplot\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 321\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 322\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mswitch_backend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbackend\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Code/pyblique/venv/lib/python3.7/site-packages/matplotlib/pyplot.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2280\u001b[0m \u001b[0mdict\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__setitem__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrcParams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"backend\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrcsetup\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_auto_backend_sentinel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2281\u001b[0m \u001b[0;31m# Set up the backend.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2282\u001b[0;31m \u001b[0mswitch_backend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrcParams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"backend\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2283\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2284\u001b[0m \u001b[0;31m# Just to be safe. Interactive mode can be turned on without\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Code/pyblique/venv/lib/python3.7/site-packages/matplotlib/pyplot.py\u001b[0m in \u001b[0;36mswitch_backend\u001b[0;34m(newbackend)\u001b[0m\n\u001b[1;32m 219\u001b[0m else \"matplotlib.backends.backend_{}\".format(newbackend.lower()))\n\u001b[1;32m 220\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 221\u001b[0;31m \u001b[0mbackend_mod\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mimportlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimport_module\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbackend_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 222\u001b[0m Backend = type(\n\u001b[1;32m 223\u001b[0m \"Backend\", (matplotlib.backends._Backend,), vars(backend_mod))\n", - "\u001b[0;32m/usr/local/Cellar/python/3.7.6_1/Frameworks/Python.framework/Versions/3.7/lib/python3.7/importlib/__init__.py\u001b[0m in \u001b[0;36mimport_module\u001b[0;34m(name, package)\u001b[0m\n\u001b[1;32m 125\u001b[0m \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 126\u001b[0m \u001b[0mlevel\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 127\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_bootstrap\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_gcd_import\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mlevel\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpackage\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlevel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 128\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 129\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Code/pyblique/venv/lib/python3.7/site-packages/matplotlib/backends/backend_qt5agg.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mcbook\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mbackend_agg\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mFigureCanvasAgg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m from .backend_qt5 import (\n\u001b[0m\u001b[1;32m 12\u001b[0m \u001b[0mQtCore\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mQtGui\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mQtWidgets\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_BackendQT5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mFigureCanvasQT\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mFigureManagerQT\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m NavigationToolbar2QT, backend_version)\n", - "\u001b[0;32m~/Code/pyblique/venv/lib/python3.7/site-packages/matplotlib/backends/backend_qt5.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0m_Backend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mFigureCanvasBase\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mFigureManagerBase\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNavigationToolbar2\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m TimerBase, cursors, ToolContainerBase, StatusbarBase, MouseButton)\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mmatplotlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackends\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mqt_editor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfigureoptions\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mfigureoptions\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 16\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mmatplotlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackends\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mqt_editor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformsubplottool\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mUiSubplotTool\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mmatplotlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackend_managers\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mToolManager\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Code/pyblique/venv/lib/python3.7/site-packages/matplotlib/backends/qt_editor/figureoptions.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmatplotlib\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mmatplotlib\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mcbook\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcolors\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mmcolors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmarkers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mimage\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mmimage\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mmatplotlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackends\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mqt_compat\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mQtGui\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 13\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mmatplotlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackends\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mqt_editor\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0m_formlayout\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/Code/pyblique/venv/lib/python3.7/site-packages/matplotlib/backends/qt_compat.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 166\u001b[0m \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 167\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 168\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mImportError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Failed to import any qt binding\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 169\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# We should not get there.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 170\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mAssertionError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Unexpected QT_API: {}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mQT_API\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mImportError\u001b[0m: Failed to import any qt binding" - ] + "output_type": "display_data", + "data": { + "text/plain": "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …", + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "0025f832c1734afc944021e5990c2d11" + } + }, + "metadata": {} } ], "source": [ - "%matplotlib qt\n", + "%matplotlib widget\n", "from mpl_toolkits.mplot3d import Axes3D\n", "import matplotlib.pyplot as plt\n", "from matplotlib import cm\n", @@ -229,8 +208,8 @@ "ax.zaxis.set_major_formatter(FormatStrFormatter('%.02f'))\n", "\n", "# rotate the axes and update\n", - "for angle in range(0, 360):\n", - " ax.view_init(30, 40)\n", + "#for angle in range(0, 360):\n", + "# ax.view_init(30, 40)\n", "\n", "# Add a color bar which maps values to colors.\n", "fig.colorbar(surf, shrink=0.5, aspect=5)\n", diff --git a/trees/Siterator.py b/trees/Siterator.py index baa6c03..c32820f 100644 --- a/trees/Siterator.py +++ b/trees/Siterator.py @@ -1,22 +1,34 @@ +''' +__author__ = "Ricardo Montañana Gómez" +__copyright__ = "Copyright 2020, Ricardo Montañana Gómez" +__license__ = "MIT" +__version__ = "0.9" +Inorder iterator for the binary tree of Snodes +Uses LinearSVC +''' from trees.Snode import Snode + class Siterator: - """Implements an inorder iterator + """Inorder iterator """ + def __init__(self, tree: Snode): self._stack = [] self._push(tree) - - def hasNext(self) -> bool: - return len(self._stack) > 0 + + def __iter__(self): + return self def _push(self, node: Snode): while (node is not None): self._stack.insert(0, node) node = node.get_down() - def next(self) -> Snode: + def __next__(self) -> Snode: + if len(self._stack) == 0: + raise StopIteration() node = self._stack.pop() self._push(node.get_up()) return node diff --git a/trees/Snode.py b/trees/Snode.py index 0aa3361..e7f1bf5 100644 --- a/trees/Snode.py +++ b/trees/Snode.py @@ -65,6 +65,6 @@ class Snode: def __str__(self) -> str: if self.is_leaf(): - return f"{self._title} - Leaf class={self._class} belief={self._belief:.6f} counts={np.unique(self._y, return_counts=True)}\n" + return f"{self._title} - Leaf class={self._class} belief={self._belief:.6f} counts={np.unique(self._y, return_counts=True)}" else: - return f"{self._title}\n" + return f"{self._title}" diff --git a/trees/Stree.py b/trees/Stree.py index 2fe1e53..e82bcfe 100644 --- a/trees/Stree.py +++ b/trees/Stree.py @@ -1,4 +1,3 @@ -# This Python file uses the following encoding: utf-8 ''' __author__ = "Ricardo Montañana Gómez" __copyright__ = "Copyright 2020, Ricardo Montañana Gómez" @@ -16,13 +15,14 @@ from sklearn.svm import LinearSVC from sklearn.utils.validation import check_X_y, check_array, check_is_fitted from trees.Snode import Snode +from trees.Siterator import Siterator class Stree(BaseEstimator, ClassifierMixin): """ """ - def __init__(self, C=1.0, max_iter: int=1000, random_state: int=0, use_predictions: bool=False): + def __init__(self, C=1.0, max_iter: int = 1000, random_state: int = 0, use_predictions: bool = False): self._max_iter = max_iter self._C = C self._random_state = random_state @@ -184,28 +184,15 @@ class Stree(BaseEstimator, ClassifierMixin): right = (yp == y).astype(int) return np.sum(right) / len(y) - def __print_tree(self, tree: Snode, only_leaves=False) -> str: - if not only_leaves: - output = str(tree) - else: - output = '' - if tree.is_leaf(): - if only_leaves: - output = str(tree) - return output - output += self.__print_tree(tree.get_down(), only_leaves) - output += self.__print_tree(tree.get_up(), only_leaves) + def __iter__(self): + return Siterator(self._tree) + + def __str__(self) -> str: + output = '' + for i in self: + output += str(i) + '\n' return output - def show_tree(self, only_leaves=False): - if only_leaves: - print(self.__print_tree(self._tree, only_leaves=True)) - else: - print(self) - - def __str__(self): - return self.__print_tree(self._tree) - def _save_datasets(self, tree: Snode, catalog: typing.TextIO, number: int): """Save the dataset of the node in a csv file @@ -232,4 +219,4 @@ class Stree(BaseEstimator, ClassifierMixin): """Save the every dataset stored in the tree to check with manual classifier """ with open(self.get_catalog_name(), 'w', encoding='utf-8') as catalog: - self._save_datasets(self._tree, catalog, 1) \ No newline at end of file + self._save_datasets(self._tree, catalog, 1)