mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-15 15:36:00 +00:00
Grapher working
This commit is contained in:
7
main.py
7
main.py
@@ -50,9 +50,8 @@ print(f"Classifier's accuracy (test) : {clf.score(Xtest, ytest):.4f}")
|
||||
proba = clf.predict_proba(Xtest)
|
||||
print("Checking that we have correct probabilities, these are probabilities of sample belonging to class 1")
|
||||
res0 = proba[proba[:, 0] == 0]
|
||||
res1 = proba[proba[:, 0] == 0]
|
||||
print("++++++++++res0++++++++++++")
|
||||
res1 = proba[proba[:, 0] == 1]
|
||||
print("++++++++++res0 > .8++++++++++++")
|
||||
print(res0[res0[:, 1] > .8])
|
||||
print("**********res1************")
|
||||
print("**********res1 < .4************")
|
||||
print(res1[res1[:, 1] < .4])
|
||||
print(clf.predict_proba(Xtest))
|
42
test2.ipynb
42
test2.ipynb
@@ -5,6 +5,21 @@
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#\n",
|
||||
"# Google Colab setup\n",
|
||||
"#\n",
|
||||
"#!git clone https://github.com/Doctorado-ML/STree.git\n",
|
||||
"# Set working dir to Stree\n",
|
||||
"#import os\n",
|
||||
"#os.chdir(\"STree\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
@@ -17,7 +32,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -29,13 +44,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": "Fraud: 0.173% 492\nValid: 99.827% 284315\nX.shape (1492, 28) y.shape (1492,)\nFraud: 33.244% 496\nValid: 66.756% 996\n"
|
||||
"text": "Fraud: 0.173% 492\nValid: 99.827% 284315\nX.shape (1492, 28) y.shape (1492,)\nFraud: 32.976% 492\nValid: 67.024% 1000\n"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
@@ -84,20 +99,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": "************** C=0.001 ****************************\nClassifier's accuracy (train): 0.9559\nClassifier's accuracy (test) : 0.9531\nroot\nroot - Down\nroot - Down - Down, <cgaf> - Leaf class=1 belief=0.980769 counts=(array([0, 1]), array([ 6, 306]))\nroot - Up\nroot - Up - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up, <cgaf> - Leaf class=0 belief=0.945205 counts=(array([0, 1]), array([690, 40]))\n\n**************************************************\n************** C=0.01 ****************************\nClassifier's accuracy (train): 0.9588\nClassifier's accuracy (test) : 0.9509\nroot\nroot - Down\nroot - Down - Down, <cgaf> - Leaf class=1 belief=0.990323 counts=(array([0, 1]), array([ 3, 307]))\nroot - Up, <cgaf> - Leaf class=0 belief=0.945205 counts=(array([0, 1]), array([690, 40]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([4]))\n\n**************************************************\n************** C=1 ****************************\nClassifier's accuracy (train): 0.9732\nClassifier's accuracy (test) : 0.9531\nroot\nroot - Down\nroot - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([313]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([6]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([9]))\nroot - Up - Up, <cgaf> - Leaf class=0 belief=0.960839 counts=(array([0, 1]), array([687, 28]))\nroot - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\n\n**************************************************\n************** C=5 ****************************\nClassifier's accuracy (train): 0.9732\nClassifier's accuracy (test) : 0.9509\nroot\nroot - Down\nroot - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([312]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([7]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([10]))\nroot - Up - Up, <cgaf> - Leaf class=0 belief=0.960784 counts=(array([0, 1]), array([686, 28]))\nroot - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\n\n**************************************************\n************** C=17 ****************************\nClassifier's accuracy (train): 0.9741\nClassifier's accuracy (test) : 0.9531\nroot\nroot - Down\nroot - Down - Down\nroot - Down - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([312]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([8]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([10]))\nroot - Down - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up, <cgaf> - Leaf class=0 belief=0.961756 counts=(array([0, 1]), array([679, 27]))\nroot - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([7]))\n\n**************************************************\n0.8116 secs\n"
|
||||
"text": "************** C=0.001 ****************************\nClassifier's accuracy (train): 0.9559\nClassifier's accuracy (test) : 0.9442\nroot\nroot - Down, <cgaf> - Leaf class=1 belief=0.986928 counts=(array([0, 1]), array([ 4, 302]))\nroot - Up, <cgaf> - Leaf class=0 belief=0.943089 counts=(array([0, 1]), array([696, 42]))\n\n**************************************************\n************** C=0.01 ****************************\nClassifier's accuracy (train): 0.9588\nClassifier's accuracy (test) : 0.9397\nroot\nroot - Down, <cgaf> - Leaf class=1 belief=0.993443 counts=(array([0, 1]), array([ 2, 303]))\nroot - Up, <cgaf> - Leaf class=0 belief=0.944520 counts=(array([0, 1]), array([698, 41]))\n\n**************************************************\n************** C=1 ****************************\nClassifier's accuracy (train): 0.9703\nClassifier's accuracy (test) : 0.9531\nroot\nroot - Down\nroot - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([313]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([4]))\nroot - Up, <cgaf> - Leaf class=0 belief=0.957359 counts=(array([0, 1]), array([696, 31]))\n\n**************************************************\n************** C=5 ****************************\nClassifier's accuracy (train): 0.9703\nClassifier's accuracy (test) : 0.9531\nroot\nroot - Down\nroot - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([313]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([7]))\nroot - Up, <cgaf> - Leaf class=0 belief=0.957182 counts=(array([0, 1]), array([693, 31]))\n\n**************************************************\n************** C=17 ****************************\nClassifier's accuracy (train): 0.9789\nClassifier's accuracy (test) : 0.9509\nroot\nroot - Down\nroot - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([313]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([10]))\nroot - Up\nroot - Up - Down, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([4]))\nroot - Up - Up\nroot - Up - Up - Down, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up - Up\nroot - Up - Up - Up - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([4]))\nroot - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([5]))\nroot - Up - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Up - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([3]))\nroot - Up - Up - Up - Up - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up - Up - Up - Up - Up - Up, <cgaf> - Leaf class=0 belief=0.968481 counts=(array([0, 1]), array([676, 22]))\n\n**************************************************\n0.6609 secs\n"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
@@ -115,7 +123,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -132,13 +140,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": "root\nroot - Down\nroot - Down - Down\nroot - Down - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([312]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([8]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([10]))\nroot - Down - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up, <cgaf> - Leaf class=0 belief=0.961756 counts=(array([0, 1]), array([679, 27]))\nroot - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([7]))\n"
|
||||
"text": "root\nroot - Down\nroot - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([313]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([10]))\nroot - Up\nroot - Up - Down, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([4]))\nroot - Up - Up\nroot - Up - Up - Down, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up - Up\nroot - Up - Up - Up - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([4]))\nroot - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([5]))\nroot - Up - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Up - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([3]))\nroot - Up - Up - Up - Up - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up - Up - Up - Up - Up - Up, <cgaf> - Leaf class=0 belief=0.968481 counts=(array([0, 1]), array([676, 22]))\n"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
@@ -149,13 +157,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": "root\nroot - Down\nroot - Down - Down\nroot - Down - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([312]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([8]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([10]))\nroot - Down - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up, <cgaf> - Leaf class=0 belief=0.961756 counts=(array([0, 1]), array([679, 27]))\nroot - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([7]))\n"
|
||||
"text": "root\nroot - Down\nroot - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([313]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([10]))\nroot - Up\nroot - Up - Down, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([4]))\nroot - Up - Up\nroot - Up - Up - Down, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up - Up\nroot - Up - Up - Up - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([4]))\nroot - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([5]))\nroot - Up - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Up - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([3]))\nroot - Up - Up - Up - Up - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up - Up - Up - Up - Up - Up, <cgaf> - Leaf class=0 belief=0.968481 counts=(array([0, 1]), array([676, 22]))\n"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
|
File diff suppressed because one or more lines are too long
@@ -144,15 +144,17 @@ class Stree_test(unittest.TestCase):
|
||||
"""Check that element 28 has a prediction different that the current label
|
||||
"""
|
||||
# Element 28 has a different prediction than the truth
|
||||
decimals = 8
|
||||
X, y = self._get_Xy()
|
||||
yp = self._clf.predict_proba(X[28, :].reshape(-1, X.shape[1]))
|
||||
self.assertEqual(0, yp[0:, 0])
|
||||
self.assertEqual(1, y[28])
|
||||
self.assertEqual(0.29026400766, round(yp[0, 1], 11))
|
||||
self.assertEqual(round(0.29026400766, decimals), round(yp[0, 1], decimals))
|
||||
|
||||
def test_multiple_predict_proba(self):
|
||||
# First 27 elements the predictions are the same as the truth
|
||||
num = 27
|
||||
decimals = 8
|
||||
X, y = self._get_Xy()
|
||||
yp = self._clf.predict_proba(X[:num, :])
|
||||
self.assertListEqual(y[:num].tolist(), yp[:, 0].tolist())
|
||||
@@ -161,7 +163,9 @@ class Stree_test(unittest.TestCase):
|
||||
0.30756427, 0.8318412, 0.18981198, 0.15564624, 0.25740655, 0.22923355,
|
||||
0.87365959, 0.49928689, 0.95574351, 0.28761257, 0.28906333, 0.32643692,
|
||||
0.29788483, 0.01657364, 0.81149083]
|
||||
self.assertListEqual(expected_proba, np.round(yp[:, 1], decimals=8).tolist())
|
||||
self.assertListEqual(
|
||||
np.round(expected_proba, decimals=decimals).tolist(),
|
||||
np.round(yp[:, 1], decimals=decimals).tolist())
|
||||
|
||||
def build_models(self):
|
||||
"""Build and train two models, model_clf will use the sklearn classifier to
|
||||
|
@@ -23,17 +23,26 @@ class Snode_graph(Snode):
|
||||
def set_plot_size(self, size):
|
||||
self._plot_size = size
|
||||
|
||||
def _is_pure(self) -> bool:
|
||||
"""is considered pure a leaf node with one label
|
||||
"""
|
||||
if self.is_leaf():
|
||||
return self._belief == 1.
|
||||
return False
|
||||
|
||||
def plot_hyperplane(self):
|
||||
# get the splitting hyperplane
|
||||
def hyperplane(x, y): return (-self._interceptor - self._vector[0][0] * x
|
||||
- self._vector[0][1] * y) / self._vector[0][2]
|
||||
fig = plt.figure(figsize=self._plot_size)
|
||||
ax = fig.add_subplot(1, 1, 1, projection='3d')
|
||||
tmpx = np.linspace(self._X[:, 0].min(), self._X[:, 0].max())
|
||||
tmpy = np.linspace(self._X[:, 1].min(), self._X[:, 1].max())
|
||||
xx, yy = np.meshgrid(tmpx, tmpy)
|
||||
ax.plot_surface(xx, yy, hyperplane(xx, yy), alpha=.5, antialiased=True,
|
||||
rstride=1, cstride=1, cmap='seismic')
|
||||
if not self._is_pure():
|
||||
# Can't plot hyperplane of leaves with one label because it hasn't classiffier
|
||||
tmpx = np.linspace(self._X[:, 0].min(), self._X[:, 0].max())
|
||||
tmpy = np.linspace(self._X[:, 1].min(), self._X[:, 1].max())
|
||||
xx, yy = np.meshgrid(tmpx, tmpy)
|
||||
ax.plot_surface(xx, yy, hyperplane(xx, yy), alpha=.5, antialiased=True,
|
||||
rstride=1, cstride=1, cmap='seismic')
|
||||
plt.title(self._title)
|
||||
self.plot_distribution(ax)
|
||||
return ax
|
||||
|
@@ -33,6 +33,9 @@ class Stree_grapher(Stree):
|
||||
|
||||
def _copy_tree(self, node: Snode) -> Snode_graph:
|
||||
mirror = Snode_graph(node)
|
||||
# clone node
|
||||
mirror._class = node._class
|
||||
mirror._belief = node._belief
|
||||
if node.get_down() is not None:
|
||||
mirror.set_down(self._copy_tree(node.get_down()))
|
||||
if node.get_up() is not None:
|
||||
|
Reference in New Issue
Block a user