add _linear_function to Stree

First try to implement predict_proba with sigmoid based on Platt
This commit is contained in:
2020-05-16 20:34:03 +02:00
parent 80b5cf8e72
commit 6a8ef288ce
6 changed files with 197 additions and 18 deletions

View File

@@ -1,8 +1,15 @@
# STree
Oblique Tree classifier based on SVM nodes
## Example
## Jupyter
[Binder link to test.ipynb](https://gesis.mybinder.org/binder/v2/gh/Doctorado-ML/STree/80b5cf8e722a0fc648624fbab69cd686847f906d)
## Command line
```python
python main.py
```

View File

@@ -35,7 +35,7 @@ def load_creditcard(n_examples=0):
return X, y
#X, y = load_creditcard(-5000)
#X, y = load_creditcard()
X, y = load_creditcard()
#X, y = load_creditcard()
clf = Stree(C=.01, max_iter=100, random_state=random_state)
clf.fit(X, y)
@@ -46,3 +46,4 @@ yp = clf.predict_proba(X[0, :].reshape(-1, X.shape[1]))
print(f"Predicting {y[0]} we have {yp[0, 0]} with {yp[0, 1]} of belief")
print(f"Classifier's accuracy: {clf.score(X, y, print_out=False):.4f}")
clf.show_tree(only_leaves=True)
print(clf.predict_proba(X))

View File

@@ -1,2 +1,3 @@
numpy==1.18.2
scikit-learn==0.22.2
pandas==1.0.3

View File

@@ -20,6 +20,25 @@
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import os, urllib.request\n",
"file_name = 'data/creditcard.csv'\n",
"if not os.path.isfile(file_name):\n",
" url = 'https://datahub.io/machine-learning/creditcard/r/creditcard.csv'\n",
" urllib.request.urlretrieve(url, file_name)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "*Original Fraud: 0.173% 492\n*Original Valid: 99.827% 284315\nX.shape (284807, 28) y.shape (284807, 1)\n-Generated Fraud: 0.173% 492\n-Generated Valid: 99.827% 284315\n"
}
],
"source": [
"def load_creditcard(n_examples=0):\n",
" df = pd.read_csv('data/creditcard.csv')\n",
@@ -62,13 +81,13 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "root\nroot - Down\nLeaf class=0 belief=0.999638 counts=(array([0, 1]), array([284242, 103]))\nroot - Down - Up\nroot - Down - Up - Down\nLeaf class=0 belief=0.857143 counts=(array([0, 1]), array([18, 3]))\nLeaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Down - Up - Up\nLeaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nLeaf class=1 belief=0.862069 counts=(array([0, 1]), array([ 16, 100]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down\nLeaf class=0 belief=0.920000 counts=(array([0, 1]), array([23, 2]))\nLeaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nLeaf class=1 belief=1.000000 counts=(array([1]), array([3]))\nLeaf class=1 belief=0.948980 counts=(array([0, 1]), array([ 15, 279]))\n\n60.9873 secs\n"
"text": "root\nroot - Down\nroot - Down - Down, <cgaf> - Leaf class=0 belief=0.999638 counts=(array([0, 1]), array([284242, 103]))\nroot - Down - Up\nroot - Down - Up - Down\nroot - Down - Up - Down - Down, <cgaf> - Leaf class=0 belief=0.857143 counts=(array([0, 1]), array([18, 3]))\nroot - Down - Up - Down - Up, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Down - Up - Up\nroot - Down - Up - Up - Down, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Down - Up - Up - Up, <cgaf> - Leaf class=1 belief=0.862069 counts=(array([0, 1]), array([ 16, 100]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down\nroot - Up - Down - Down - Down, <cgaf> - Leaf class=0 belief=0.920000 counts=(array([0, 1]), array([23, 2]))\nroot - Up - Down - Down - Up, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Down - Up, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([3]))\nroot - Up - Up, <cgaf> - Leaf class=1 belief=0.948980 counts=(array([0, 1]), array([ 15, 279]))\n\n41.5053 secs\n"
}
],
"source": [
@@ -81,13 +100,13 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "Accuracy: 0.999512\n0.3226 secs\n"
"text": "Accuracy: 0.999512\n0.2389 secs\n"
}
],
"source": [
@@ -98,13 +117,86 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 15,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "(284807, 2)\n0.4148 secs\n"
"text": "root\nroot - Down\nroot - Down - Down, <cgaf> - Leaf class=0 belief=0.999638 counts=(array([0, 1]), array([284242, 103]))\nroot - Down - Up\nroot - Down - Up - Down\nroot - Down - Up - Down - Down, <cgaf> - Leaf class=0 belief=0.857143 counts=(array([0, 1]), array([18, 3]))\nroot - Down - Up - Down - Up, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Down - Up - Up\nroot - Down - Up - Up - Down, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Down - Up - Up - Up, <cgaf> - Leaf class=1 belief=0.862069 counts=(array([0, 1]), array([ 16, 100]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down\nroot - Up - Down - Down - Down, <cgaf> - Leaf class=0 belief=0.920000 counts=(array([0, 1]), array([23, 2]))\nroot - Up - Down - Down - Up, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Down - Up, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([3]))\nroot - Up - Up, <cgaf> - Leaf class=1 belief=0.948980 counts=(array([0, 1]), array([ 15, 279]))\n\n"
}
],
"source": [
"print(clf)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "root - Down - Up - Down - Up, <pure>\n"
},
{
"output_type": "error",
"ename": "AttributeError",
"evalue": "'NoneType' object has no attribute 'coef_'",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-14-9b2b3e4e625d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0mprintvec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_down\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mprintvec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_up\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0mprintvec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-14-9b2b3e4e625d>\u001b[0m in \u001b[0;36mprintvec\u001b[0;34m(node)\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_leaf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mreturn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0mprintvec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_down\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 8\u001b[0m \u001b[0mprintvec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_up\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mprintvec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-14-9b2b3e4e625d>\u001b[0m in \u001b[0;36mprintvec\u001b[0;34m(node)\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mreturn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0mprintvec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_down\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0mprintvec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_up\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m \u001b[0mprintvec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-14-9b2b3e4e625d>\u001b[0m in \u001b[0;36mprintvec\u001b[0;34m(node)\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_leaf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mreturn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0mprintvec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_down\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 8\u001b[0m \u001b[0mprintvec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_up\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mprintvec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-14-9b2b3e4e625d>\u001b[0m in \u001b[0;36mprintvec\u001b[0;34m(node)\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mreturn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0mprintvec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_down\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0mprintvec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_up\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m \u001b[0mprintvec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-14-9b2b3e4e625d>\u001b[0m in \u001b[0;36mprintvec\u001b[0;34m(node)\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_vector\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_title\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_clf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcoef_\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_leaf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mreturn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mAttributeError\u001b[0m: 'NoneType' object has no attribute 'coef_'"
]
}
],
"source": [
"def printvec(node):\n",
" if node._vector is None:\n",
" print(node._title)\n",
" print(node._clf.coef_)\n",
" if node.is_leaf():\n",
" return\n",
" printvec(node.get_down())\n",
" printvec(node.get_up())\n",
"printvec(clf._tree)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "[[ 0.01272453 -0.00869691 0.00261326 0.013119 0.00200886 -0.00177167\n -0.0059729 -0.01137639 -0.01553846 -0.02868033 0.01536246 -0.02983645\n -0.00347147 -0.03159505 -0.0020723 -0.02189612 -0.03363292 -0.01214814\n 0.00413005 -0.00207565 0.00679726 0.00138744 -0.00159652 0.00035469\n -0.00397646 -0.00277537 0.01026267 0.00918578]] root - Down - Down, <cgaf> [-1.02881772]\n[[ 0.04405501 0.06724601 0.07316834 0.01356261 -0.01851765 0.00286983\n 0.01749505 0.02031383 0.00613145 0.00963547 -0.02055478 0.01969194\n -0.03824876 0.02122459 0.07457148 0.00967706 0.04688633 -0.00361551\n 0.00927978 0.01916087 -0.01800233 -0.00642752 0.01972691 -0.00539432\n -0.02530949 -0.00125887 -0.00248638 -0.01560347]] root - Down - Up - Down - Down, <cgaf> [-0.01578117]\nNone root - Down - Up - Down - Up, <pure> 0.0\n"
},
{
"output_type": "error",
"ename": "TypeError",
"evalue": "'NoneType' object is not subscriptable",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-6-aa9f79a0bb01>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0myp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict_proba\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0myp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"{time.time() - t:.4f} secs\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/Code/STree/trees/Stree.py\u001b[0m in \u001b[0;36mpredict_proba\u001b[0;34m(self, X)\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpredict_proba\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 139\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__proba\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 140\u001b[0;31m \u001b[0mresult\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_predict_values\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 141\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__proba\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/Code/STree/trees/Stree.py\u001b[0m in \u001b[0;36m_predict_values\u001b[0;34m(self, X)\u001b[0m\n\u001b[1;32m 124\u001b[0m \u001b[0;31m# setup prediction & make it happen\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 125\u001b[0m \u001b[0mindices\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 126\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mpredict_class\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 127\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 128\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_reorder_results\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/Code/STree/trees/Stree.py\u001b[0m in \u001b[0;36mpredict_class\u001b[0;34m(xp, indices, node)\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mprediction\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi_u\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0md\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi_d\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_split_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 117\u001b[0;31m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ml\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpredict_class\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi_d\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_down\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 118\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpredict_class\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi_u\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_up\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ml\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/Code/STree/trees/Stree.py\u001b[0m in \u001b[0;36mpredict_class\u001b[0;34m(xp, indices, node)\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi_u\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0md\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi_d\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_split_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 117\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ml\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpredict_class\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi_d\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_down\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 118\u001b[0;31m \u001b[0mm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpredict_class\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi_u\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_up\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 119\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ml\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[0;31m# sklearn check\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/Code/STree/trees/Stree.py\u001b[0m in \u001b[0;36mpredict_class\u001b[0;34m(xp, indices, node)\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mprediction\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi_u\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0md\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi_d\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_split_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 117\u001b[0;31m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ml\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpredict_class\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi_d\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_down\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 118\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpredict_class\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi_u\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_up\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ml\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/Code/STree/trees/Stree.py\u001b[0m in \u001b[0;36mpredict_class\u001b[0;34m(xp, indices, node)\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi_u\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0md\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi_d\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_split_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 117\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ml\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpredict_class\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0md\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi_d\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_down\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 118\u001b[0;31m \u001b[0mm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpredict_class\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi_u\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_up\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 119\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ml\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[0;31m# sklearn check\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/Code/STree/trees/Stree.py\u001b[0m in \u001b[0;36mpredict_class\u001b[0;34m(xp, indices, node)\u001b[0m\n\u001b[1;32m 110\u001b[0m \u001b[0;31m#prediction_proba = np.full((xp.shape[0], 1), node._belief)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 111\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_vector\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_title\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_interceptor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 112\u001b[0;31m \u001b[0mprediction_proba\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_linear_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mxp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 113\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprediction\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprediction_proba\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/Code/STree/trees/Stree.py\u001b[0m in \u001b[0;36m_linear_function\u001b[0;34m(self, data, node)\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_linear_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mSnode\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 47\u001b[0;31m \u001b[0mcoef\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_vector\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 48\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcoef\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mnode\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_interceptor\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mTypeError\u001b[0m: 'NoneType' object is not subscriptable"
]
}
],
"source": [
@@ -125,13 +217,13 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "0.9991397683343457\n20.9481 secs\n"
"text": "0.9991397683343457\n13.6326 secs\n"
}
],
"source": [
@@ -144,13 +236,13 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "1.0\n32.2779 secs\n"
"text": "1.0\n18.8308 secs\n"
}
],
"source": [
@@ -169,6 +261,62 @@
"clf = Stree()\n",
"check_estimator(clf)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": "array([0.36514837, 0.73029674, 1.09544512, 1.46059349])"
},
"metadata": {},
"execution_count": 9
}
],
"source": [
"import numpy as np\n",
"a = np.array([-1, 2, 3, -4])\n",
"(np.abs(a) - a.mean()) / a.std()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "[-1.05990681 -1.05181137 -0.99985686 -1.04574888] [[0.74267274 0.25732726]\n [0.74112258 0.25887742]\n [0.73103044 0.26896956]\n [0.73995773 0.26004227]]\n"
},
{
"output_type": "error",
"ename": "NameError",
"evalue": "name 'yp' is not defined",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-10-9a30c92e1d37>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclf2\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_predict_proba_lr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0myp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mNameError\u001b[0m: name 'yp' is not defined"
]
}
],
"source": [
"a = clf2.decision_function(X)\n",
"b = clf2._predict_proba_lr(X)\n",
"print(a[:4], b[:4])\n",
"print(yp[:4])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@@ -186,8 +334,8 @@
},
"orig_nbformat": 2,
"kernelspec": {
"name": "python37664bitstreevenva9e4a4efdc1042b6b577bd15fbe145ee",
"display_name": "Python 3.7.6 64-bit ('stree': venv)"
"name": "python37664bitgeneralvenvfbd0a23e74cf4e778460f5ffc6761f39",
"display_name": "Python 3.7.6 64-bit ('general': venv)"
}
},
"nbformat": 4,

View File

@@ -54,5 +54,20 @@ class Snode_test(unittest.TestCase):
# Check Class
class_computed = classes[card == max_card]
self.assertEqual(class_computed, node._class)
check_leave(self._clf._tree)
def test_nodes_coefs(self):
"""Check if the nodes of the tree have the right attributes filled
"""
def run_tree(node: Snode):
if node._belief < 1:
# only exclude pure leaves
self.assertIsNotNone(node._clf)
self.assertIsNotNone(node._clf.coef_)
self.assertIsNotNone(node._vector)
self.assertIsNotNone(node._interceptor)
if node.is_leaf():
return
run_tree(node.get_down())
run_tree(node.get_up())
run_tree(self._clf._tree)

View File

@@ -43,6 +43,10 @@ class Stree(BaseEstimator, ClassifierMixin):
setattr(self, parameter, value)
return self
def _linear_function(self, data: np.array, node: Snode) -> np.array:
coef = node._vector[0, :].reshape(-1, data.shape[1])
return data.dot(coef.T) + node._interceptor[0]
def _split_data(self, node: Snode, data: np.ndarray, indices: np.ndarray) -> list:
if self.__use_predictions:
yp = node._clf.predict(data)
@@ -50,8 +54,7 @@ class Stree(BaseEstimator, ClassifierMixin):
else:
# doesn't work with multiclass as each sample has to do inner product with its own coeficients
# computes positition of every sample is w.r.t. the hyperplane
coef = node._vector[0, :].reshape(-1, data.shape[1])
res = data.dot(coef.T) + node._interceptor[0]
res = self._linear_function(data, node)
down = res > 0
up = ~down
data_down = data[down[:, 0]] if any(down) else None
@@ -105,6 +108,7 @@ class Stree(BaseEstimator, ClassifierMixin):
prediction = np.full((xp.shape[0], 1), node._class)
if self.__proba:
prediction_proba = np.full((xp.shape[0], 1), node._belief)
#prediction_proba = self._linear_function(xp, node)
return np.append(prediction, prediction_proba, axis=1), indices
else:
return prediction, indices
@@ -134,7 +138,10 @@ class Stree(BaseEstimator, ClassifierMixin):
self.__proba = True
result, indices = self._predict_values(X)
self.__proba = False
return self._reorder_results(result.reshape(X.shape[0], 2), indices)
result = result.reshape(X.shape[0], 2)
# Sigmoidize distance like in sklearn based on Platt(1999)
#result[:, 1] = 1 / (1 + np.exp(-result[:, 1]))
return self._reorder_results(result, indices)
def score(self, X: np.array, y: np.array, print_out=True) -> float:
if not self.__trained: