Cosmetics and Siterator

This commit is contained in:
2020-05-19 16:38:59 +02:00
parent 68512b3d75
commit 95a6901f47
3 changed files with 151 additions and 50 deletions

View File

@@ -35,7 +35,7 @@
{
"output_type": "stream",
"name": "stdout",
"text": "Fraud: 0.173% 492\nValid: 99.827% 284315\nX.shape (284807, 28) y.shape (284807,)\nFraud: 0.173% 492\nValid: 99.827% 284315\n"
"text": "Fraud: 0.173% 492\nValid: 99.827% 284315\nX.shape (1492, 28) y.shape (1492,)\nFraud: 33.311% 497\nValid: 66.689% 995\n"
}
],
"source": [
@@ -74,7 +74,7 @@
"\n",
"# data = load_creditcard(-5000) # Take all true samples + 5000 of the others\n",
"# data = load_creditcard(5000) # Take the first 5000 samples\n",
"data = load_creditcard() # Take all the samples\n",
"data = load_creditcard(-1000) # Take all the samples\n",
"\n",
"Xtrain = data[0]\n",
"Xtest = data[1]\n",
@@ -82,6 +82,13 @@
"ytest = data[3]"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 4,
@@ -90,14 +97,19 @@
{
"output_type": "stream",
"name": "stdout",
"text": "root\nroot - Down\nroot - Down - Down, <cgaf> - Leaf class=1 belief=0.941799 counts=(array([0, 1]), array([ 11, 178]))\nroot - Down - Up\nroot - Down - Up - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([2]))\nroot - Down - Up - Up, <cgaf> - Leaf class=0 belief=0.952381 counts=(array([0, 1]), array([20, 1]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, <cgaf> - Leaf class=1 belief=0.902174 counts=(array([0, 1]), array([ 9, 83]))\nroot - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([14]))\nroot - Up - Up, <cgaf> - Leaf class=0 belief=0.999598 counts=(array([0, 1]), array([198966, 80]))\n\n42.9141 secs\n"
"text": "************** C=0.001 ****************************\nClassifier's accuracy (train): 0.9626\nClassifier's accuracy (test) : 0.9487\nroot\nroot - Down\nroot - Down - Down, <cgaf> - Leaf class=1 belief=0.978261 counts=(array([0, 1]), array([ 7, 315]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up, <cgaf> - Leaf class=0 belief=0.955432 counts=(array([0, 1]), array([686, 32]))\n\n**************************************************\n************** C=0.01 ****************************\nClassifier's accuracy (train): 0.9636\nClassifier's accuracy (test) : 0.9509\nroot\nroot - Down\nroot - Down - Down, <cgaf> - Leaf class=1 belief=0.987421 counts=(array([0, 1]), array([ 4, 314]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up, <cgaf> - Leaf class=0 belief=0.953039 counts=(array([0, 1]), array([690, 34]))\n\n**************************************************\n************** C=1 ****************************\nClassifier's accuracy (train): 0.9703\nClassifier's accuracy (test) : 0.9531\nroot\nroot - Down\nroot - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([316]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([5]))\nroot - Up\nroot - Up - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up, <cgaf> - Leaf class=0 belief=0.957064 counts=(array([0, 1]), array([691, 31]))\n\n**************************************************\n************** C=5 ****************************\nClassifier's accuracy (train): 0.9684\nClassifier's accuracy (test) : 0.9554\nroot\nroot - Down\nroot - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([315]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([6]))\nroot - Up, <cgaf> - Leaf class=0 belief=0.954357 counts=(array([0, 1]), array([690, 33]))\n\n**************************************************\n************** C=17 ****************************\nClassifier's accuracy (train): 0.9761\nClassifier's accuracy (test) : 0.9464\nroot\nroot - Down\nroot - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([316]))\nroot - Down - Up\nroot - Down - Up - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Down - Up - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([9]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([2]))\nroot - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([4]))\nroot - Up - Up\nroot - Up - Up - Down\nroot - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([4]))\nroot - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([3]))\nroot - Up - Up - Up, <cgaf> - Leaf class=0 belief=0.964539 counts=(array([0, 1]), array([680, 25]))\n\n**************************************************\n0.4014 secs\n"
}
],
"source": [
"t = time.time()\n",
"clf = Stree(C=.01, random_state=random_state)\n",
"clf.fit(Xtrain, ytrain)\n",
"print(clf)\n",
"for C in (.001, .01, 1, 5, 17):\n",
" clf = Stree(C=C, random_state=random_state)\n",
" clf.fit(Xtrain, ytrain)\n",
" print(f\"************** C={C} ****************************\")\n",
" print(f\"Classifier's accuracy (train): {clf.score(Xtrain, ytrain):.4f}\")\n",
" print(f\"Classifier's accuracy (test) : {clf.score(Xtest, ytest):.4f}\")\n",
" print(clf)\n",
" print(f\"**************************************************\")\n",
"print(f\"{time.time() - t:.4f} secs\")"
]
},
@@ -105,17 +117,17 @@
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n0.2084 secs\n"
}
],
"outputs": [],
"source": [
"t = time.time()\n",
"print(clf.predict(Xtest)[:17])\n",
"print(f\"{time.time() - t:.4f} secs\")"
"import numpy as np\n",
"from sklearn.preprocessing import StandardScaler\n",
"from sklearn.svm import LinearSVC\n",
"from sklearn.calibration import CalibratedClassifierCV\n",
"scaler = StandardScaler()\n",
"cclf = CalibratedClassifierCV(base_estimator=LinearSVC(), cv=5)\n",
"cclf.fit(Xtrain, ytrain)\n",
"res = cclf.predict_proba(Xtest)\n",
"#an array containing probabilities of belonging to the 1st class"
]
},
{
@@ -126,31 +138,21 @@
{
"output_type": "stream",
"name": "stdout",
"text": "[[0. 0.26356965]\n [0. 0.22665372]\n [0. 0.25678353]\n [0. 0.26056019]\n [0. 0.26583006]\n [0. 0.24360041]\n [0. 0.26366182]\n [0. 0.26012045]\n [0. 0.2298345 ]\n [0. 0.25726294]\n [0. 0.25909988]\n [0. 0.25940575]\n [0. 0.24256254]\n [0. 0.15094485]\n [0. 0.26327588]\n [0. 0.26382949]\n [0. 0.26290957]]\n0.2083 secs\n"
"text": "root\nroot - Down\nroot - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([316]))\nroot - Down - Up\nroot - Down - Up - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Down - Up - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([9]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([2]))\nroot - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([4]))\nroot - Up - Up\nroot - Up - Up - Down\nroot - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([4]))\nroot - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([3]))\nroot - Up - Up - Up, <cgaf> - Leaf class=0 belief=0.964539 counts=(array([0, 1]), array([680, 25]))\n\n"
}
],
"source": [
"t = time.time()\n",
"print(clf.predict_proba(Xtest)[:17, :])\n",
"print(f\"{time.time() - t:.4f} secs\")"
"print(clf)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "Classifier's accuracy (train): 0.9995\nClassifier's accuracy (test) : 0.9995\n0.5074 secs\n"
}
],
"outputs": [],
"source": [
"t = time.time()\n",
"print(f\"Classifier's accuracy (train): {clf.score(Xtrain, ytrain):.4f}\")\n",
"print(f\"Classifier's accuracy (test) : {clf.score(Xtest, ytest):.4f}\")\n",
"print(f\"{time.time() - t:.4f} secs\")"
"from trees.Siterator import Siterator\n",
"it = Siterator(clf._tree)"
]
},
{
@@ -161,35 +163,87 @@
{
"output_type": "stream",
"name": "stdout",
"text": "0.9993211848834895\n13.3835 secs\n"
"text": "root\n\nroot - Down\n\nroot - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([316]))\n\nroot - Up\n\nroot - Up - Down\n\nroot - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([2]))\n\nroot - Down - Up\n\nroot - Down - Up - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\n\nroot - Up - Up\n\nroot - Up - Up - Down\n\nroot - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([4]))\n\nroot - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([4]))\n\nroot - Down - Up - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([9]))\n\nroot - Up - Up - Up, <cgaf> - Leaf class=0 belief=0.964539 counts=(array([0, 1]), array([680, 25]))\n\nroot - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([3]))\n\n"
}
],
"source": [
"t = time.time()\n",
"clf2 = LinearSVC(C=.01, random_state=random_state)\n",
"clf2.fit(Xtrain, ytrain)\n",
"print(clf2.score(Xtest, ytest))\n",
"print(f\"{time.time() - t:.4f} secs\")"
"while(it.hasNext()):\n",
" print(it.next())"
]
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 1,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "0.9991573329588147\n22.7635 secs\n"
"output_type": "error",
"ename": "ImportError",
"evalue": "Failed to import any qt binding",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-1-f39b24919c22>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mget_ipython\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_line_magic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'matplotlib'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'qt'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mmpl_toolkits\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmplot3d\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mAxes3D\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmatplotlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpyplot\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mmatplotlib\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mcm\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mmatplotlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mticker\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mLinearLocator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mFormatStrFormatter\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/Code/pyblique/venv/lib/python3.7/site-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mrun_line_magic\u001b[0;34m(self, magic_name, line, _stack_depth)\u001b[0m\n\u001b[1;32m 2315\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'local_ns'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_getframe\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstack_depth\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf_locals\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2316\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuiltin_trap\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2317\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2318\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2319\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<decorator-gen-108>\u001b[0m in \u001b[0;36mmatplotlib\u001b[0;34m(self, line)\u001b[0m\n",
"\u001b[0;32m~/Code/pyblique/venv/lib/python3.7/site-packages/IPython/core/magic.py\u001b[0m in \u001b[0;36m<lambda>\u001b[0;34m(f, *a, **k)\u001b[0m\n\u001b[1;32m 185\u001b[0m \u001b[0;31m# but it's overkill for just that one bit of state.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 186\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mmagic_deco\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 187\u001b[0;31m \u001b[0mcall\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 188\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 189\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcallable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/Code/pyblique/venv/lib/python3.7/site-packages/IPython/core/magics/pylab.py\u001b[0m in \u001b[0;36mmatplotlib\u001b[0;34m(self, line)\u001b[0m\n\u001b[1;32m 97\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Available matplotlib backends: %s\"\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mbackends_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 98\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 99\u001b[0;31m \u001b[0mgui\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshell\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menable_matplotlib\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgui\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlower\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgui\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgui\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 100\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_show_matplotlib_backend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgui\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/Code/pyblique/venv/lib/python3.7/site-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36menable_matplotlib\u001b[0;34m(self, gui)\u001b[0m\n\u001b[1;32m 3417\u001b[0m \u001b[0mgui\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfind_gui_and_backend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpylab_gui_select\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3418\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3419\u001b[0;31m \u001b[0mpt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mactivate_matplotlib\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbackend\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3420\u001b[0m \u001b[0mpt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconfigure_inline_support\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3421\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/Code/pyblique/venv/lib/python3.7/site-packages/IPython/core/pylabtools.py\u001b[0m in \u001b[0;36mactivate_matplotlib\u001b[0;34m(backend)\u001b[0m\n\u001b[1;32m 318\u001b[0m \u001b[0;31m# when this function runs.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 319\u001b[0m \u001b[0;31m# So avoid needing matplotlib attribute-lookup to access pyplot.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 320\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mmatplotlib\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mpyplot\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 321\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 322\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mswitch_backend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbackend\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/Code/pyblique/venv/lib/python3.7/site-packages/matplotlib/pyplot.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 2280\u001b[0m \u001b[0mdict\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__setitem__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrcParams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"backend\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrcsetup\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_auto_backend_sentinel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2281\u001b[0m \u001b[0;31m# Set up the backend.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2282\u001b[0;31m \u001b[0mswitch_backend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrcParams\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"backend\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2283\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2284\u001b[0m \u001b[0;31m# Just to be safe. Interactive mode can be turned on without\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/Code/pyblique/venv/lib/python3.7/site-packages/matplotlib/pyplot.py\u001b[0m in \u001b[0;36mswitch_backend\u001b[0;34m(newbackend)\u001b[0m\n\u001b[1;32m 219\u001b[0m else \"matplotlib.backends.backend_{}\".format(newbackend.lower()))\n\u001b[1;32m 220\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 221\u001b[0;31m \u001b[0mbackend_mod\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mimportlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimport_module\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbackend_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 222\u001b[0m Backend = type(\n\u001b[1;32m 223\u001b[0m \"Backend\", (matplotlib.backends._Backend,), vars(backend_mod))\n",
"\u001b[0;32m/usr/local/Cellar/python/3.7.6_1/Frameworks/Python.framework/Versions/3.7/lib/python3.7/importlib/__init__.py\u001b[0m in \u001b[0;36mimport_module\u001b[0;34m(name, package)\u001b[0m\n\u001b[1;32m 125\u001b[0m \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 126\u001b[0m \u001b[0mlevel\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 127\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_bootstrap\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_gcd_import\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mlevel\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpackage\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlevel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 128\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 129\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/Code/pyblique/venv/lib/python3.7/site-packages/matplotlib/backends/backend_qt5agg.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mcbook\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mbackend_agg\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mFigureCanvasAgg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m from .backend_qt5 import (\n\u001b[0m\u001b[1;32m 12\u001b[0m \u001b[0mQtCore\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mQtGui\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mQtWidgets\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_BackendQT5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mFigureCanvasQT\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mFigureManagerQT\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m NavigationToolbar2QT, backend_version)\n",
"\u001b[0;32m~/Code/pyblique/venv/lib/python3.7/site-packages/matplotlib/backends/backend_qt5.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0m_Backend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mFigureCanvasBase\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mFigureManagerBase\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mNavigationToolbar2\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m TimerBase, cursors, ToolContainerBase, StatusbarBase, MouseButton)\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mmatplotlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackends\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mqt_editor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfigureoptions\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mfigureoptions\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 16\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mmatplotlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackends\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mqt_editor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformsubplottool\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mUiSubplotTool\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mmatplotlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackend_managers\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mToolManager\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/Code/pyblique/venv/lib/python3.7/site-packages/matplotlib/backends/qt_editor/figureoptions.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmatplotlib\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mmatplotlib\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mcbook\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcolors\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mmcolors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmarkers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mimage\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mmimage\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mmatplotlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackends\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mqt_compat\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mQtGui\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 13\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mmatplotlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackends\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mqt_editor\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0m_formlayout\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/Code/pyblique/venv/lib/python3.7/site-packages/matplotlib/backends/qt_compat.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 166\u001b[0m \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 167\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 168\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mImportError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Failed to import any qt binding\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 169\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# We should not get there.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 170\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mAssertionError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Unexpected QT_API: {}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mQT_API\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mImportError\u001b[0m: Failed to import any qt binding"
]
}
],
"source": [
"t = time.time()\n",
"clf3 = DecisionTreeClassifier(random_state=random_state)\n",
"clf3.fit(Xtrain, ytrain)\n",
"print(clf3.score(Xtest, ytest))\n",
"print(f\"{time.time() - t:.4f} secs\")"
"%matplotlib qt\n",
"from mpl_toolkits.mplot3d import Axes3D\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib import cm\n",
"from matplotlib.ticker import LinearLocator, FormatStrFormatter\n",
"import numpy as np\n",
"\n",
"fig = plt.figure()\n",
"ax = fig.gca(projection='3d')\n",
"\n",
"scale = 8\n",
"# Make data.\n",
"X = np.arange(-scale, scale, 0.25)\n",
"Y = np.arange(-scale, scale, 0.25)\n",
"X, Y = np.meshgrid(X, Y)\n",
"Z = X**2 + Y**2\n",
"\n",
"# Plot the surface.\n",
"surf = ax.plot_surface(X, Y, Z, cmap=cm.coolwarm,\n",
" linewidth=0, antialiased=False)\n",
"\n",
"# Customize the z axis.\n",
"ax.set_zlim(0, 100)\n",
"ax.zaxis.set_major_locator(LinearLocator(10))\n",
"ax.zaxis.set_major_formatter(FormatStrFormatter('%.02f'))\n",
"\n",
"# rotate the axes and update\n",
"for angle in range(0, 360):\n",
" ax.view_init(30, 40)\n",
"\n",
"# Add a color bar which maps values to colors.\n",
"fig.colorbar(surf, shrink=0.5, aspect=5)\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {

22
trees/Siterator.py Normal file
View File

@@ -0,0 +1,22 @@
from trees.Snode import Snode
class Siterator:
"""Implements an inorder iterator
"""
def __init__(self, tree: Snode):
self._stack = []
self._push(tree)
def hasNext(self) -> bool:
return len(self._stack) > 0
def _push(self, node: Snode):
while (node is not None):
self._stack.insert(0, node)
node = node.get_down()
def next(self) -> Snode:
node = self._stack.pop()
self._push(node.get_up())
return node

View File

@@ -131,7 +131,26 @@ class Stree(BaseEstimator, ClassifierMixin):
return self._reorder_results(*predict_class(X, indices, self._tree))
def predict_proba(self, X: np.array) -> np.array:
"""Computes an approximation of the probability of samples belonging to class 1
(nothing more, nothing less)
:param X: dataset
:type X: np.array
"""
def predict_class(xp: np.array, indices: np.array, dist: np.array, node: Snode) -> np.array:
"""Run the tree to compute predictions
:param xp: subdataset of samples
:type xp: np.array
:param indices: indices of subdataset samples to rebuild original order
:type indices: np.array
:param dist: distances of every sample to the hyperplane or the father node
:type dist: np.array
:param node: node of the leaf with the class
:type node: Snode
:return: array of labels and distances, array of indices
:rtype: np.array
"""
if xp is None:
return [], []
if node.is_leaf():
@@ -151,11 +170,14 @@ class Stree(BaseEstimator, ClassifierMixin):
indices = np.arange(X.shape[0])
result, indices = predict_class(X, indices, [], self._tree)
result = result.reshape(X.shape[0], 2)
# Sigmoidize distance like in sklearn based on Platt(1999)
# Turn distances to hyperplane into probabilities based on fitting distances
# of samples to its hyperplane that classified them, to the sigmoid function
result[:, 1] = 1 / (1 + np.exp(-result[:, 1]))
return self._reorder_results(result, indices)
def score(self, X: np.array, y: np.array) -> float:
"""Return accuracy
"""
if not self.__trained:
self.fit(X, y)
yp = self.predict(X).reshape(y.shape)
@@ -187,9 +209,12 @@ class Stree(BaseEstimator, ClassifierMixin):
def _save_datasets(self, tree: Snode, catalog: typing.TextIO, number: int):
"""Save the dataset of the node in a csv file
Arguments:
tree {Snode} -- node with data to save
number {int} -- a number to make different file names
:param tree: node with data to save
:type tree: Snode
:param catalog: catalog file handler
:type catalog: typing.TextIO
:param number: sequential number for the generated file name
:type number: int
"""
data = np.append(tree._X, tree._y.reshape(-1, 1), axis=1)
name = f"{self.__folder}dataset{number}.csv"