mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-15 15:36:00 +00:00
first approx to grapher
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
numpy==1.18.2
|
||||
scikit-learn==0.22.2
|
||||
pandas==1.0.3
|
||||
pandas==1.0.3
|
||||
matplotlib==3.2.1
|
76
test2.ipynb
76
test2.ipynb
@@ -35,7 +35,7 @@
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": "Fraud: 0.173% 492\nValid: 99.827% 284315\nX.shape (1492, 28) y.shape (1492,)\nFraud: 32.976% 492\nValid: 67.024% 1000\n"
|
||||
"text": "Fraud: 0.173% 492\nValid: 99.827% 284315\nX.shape (1492, 28) y.shape (1492,)\nFraud: 33.244% 496\nValid: 66.756% 996\n"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
@@ -97,7 +97,7 @@
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": "************** C=0.001 ****************************\nClassifier's accuracy (train): 0.9550\nClassifier's accuracy (test) : 0.9487\nroot\nroot - Down\nroot - Down - Down, <cgaf> - Leaf class=1 belief=0.977346 counts=(array([0, 1]), array([ 7, 302]))\nroot - Up\nroot - Up - Down, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up\nroot - Up - Up - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([2]))\nroot - Up - Up - Up, <cgaf> - Leaf class=0 belief=0.945280 counts=(array([0, 1]), array([691, 40]))\n\n**************************************************\n************** C=0.01 ****************************\nClassifier's accuracy (train): 0.9569\nClassifier's accuracy (test) : 0.9576\nroot\nroot - Down, <cgaf> - Leaf class=1 belief=0.986971 counts=(array([0, 1]), array([ 4, 303]))\nroot - Up, <cgaf> - Leaf class=0 belief=0.944369 counts=(array([0, 1]), array([696, 41]))\n\n**************************************************\n************** C=1 ****************************\nClassifier's accuracy (train): 0.9674\nClassifier's accuracy (test) : 0.9554\nroot\nroot - Down\nroot - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([310]))\nroot - Up, <cgaf> - Leaf class=0 belief=0.953232 counts=(array([0, 1]), array([693, 34]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([7]))\n\n**************************************************\n************** C=5 ****************************\nClassifier's accuracy (train): 0.9693\nClassifier's accuracy (test) : 0.9487\nroot\nroot - Down\nroot - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([310]))\nroot - Up\nroot - Up - Down, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([7]))\nroot - Up - Up\nroot - Up - Up - Down, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up - Up\nroot - Up - Up - Up - Down, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([2]))\nroot - Up - Up - Up - Up - Up, <cgaf> - Leaf class=0 belief=0.955494 counts=(array([0, 1]), array([687, 32]))\nroot - Up - Up - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\n\n**************************************************\n************** C=17 ****************************\nClassifier's accuracy (train): 0.9780\nClassifier's accuracy (test) : 0.9487\nroot\nroot - Down\nroot - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([301]))\nroot - Up\nroot - Up - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([2]))\nroot - Down - Up\nroot - Down - Up - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([15]))\nroot - Up - Up\nroot - Up - Up - Down\nroot - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([3]))\nroot - Down - Up - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([15]))\nroot - Up - Up - Up, <cgaf> - Leaf class=0 belief=0.967468 counts=(array([0, 1]), array([684, 23]))\nroot - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\n\n**************************************************\n0.7277 secs\n"
|
||||
"text": "************** C=0.001 ****************************\nClassifier's accuracy (train): 0.9559\nClassifier's accuracy (test) : 0.9531\nroot\nroot - Down\nroot - Down - Down, <cgaf> - Leaf class=1 belief=0.980769 counts=(array([0, 1]), array([ 6, 306]))\nroot - Up\nroot - Up - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up, <cgaf> - Leaf class=0 belief=0.945205 counts=(array([0, 1]), array([690, 40]))\n\n**************************************************\n************** C=0.01 ****************************\nClassifier's accuracy (train): 0.9588\nClassifier's accuracy (test) : 0.9509\nroot\nroot - Down\nroot - Down - Down, <cgaf> - Leaf class=1 belief=0.990323 counts=(array([0, 1]), array([ 3, 307]))\nroot - Up, <cgaf> - Leaf class=0 belief=0.945205 counts=(array([0, 1]), array([690, 40]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([4]))\n\n**************************************************\n************** C=1 ****************************\nClassifier's accuracy (train): 0.9732\nClassifier's accuracy (test) : 0.9531\nroot\nroot - Down\nroot - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([313]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([6]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([9]))\nroot - Up - Up, <cgaf> - Leaf class=0 belief=0.960839 counts=(array([0, 1]), array([687, 28]))\nroot - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\n\n**************************************************\n************** C=5 ****************************\nClassifier's accuracy (train): 0.9732\nClassifier's accuracy (test) : 0.9509\nroot\nroot - Down\nroot - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([312]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([7]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([10]))\nroot - Up - Up, <cgaf> - Leaf class=0 belief=0.960784 counts=(array([0, 1]), array([686, 28]))\nroot - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\n\n**************************************************\n************** C=17 ****************************\nClassifier's accuracy (train): 0.9741\nClassifier's accuracy (test) : 0.9531\nroot\nroot - Down\nroot - Down - Down\nroot - Down - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([312]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([8]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([10]))\nroot - Down - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up, <cgaf> - Leaf class=0 belief=0.961756 counts=(array([0, 1]), array([679, 27]))\nroot - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([7]))\n\n**************************************************\n0.8116 secs\n"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
@@ -132,97 +132,37 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": "root\nroot - Down\nroot - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([301]))\nroot - Up\nroot - Up - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([2]))\nroot - Down - Up\nroot - Down - Up - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([15]))\nroot - Up - Up\nroot - Up - Up - Down\nroot - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([3]))\nroot - Down - Up - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([15]))\nroot - Up - Up - Up, <cgaf> - Leaf class=0 belief=0.967468 counts=(array([0, 1]), array([684, 23]))\nroot - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\n"
|
||||
"text": "root\nroot - Down\nroot - Down - Down\nroot - Down - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([312]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([8]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([10]))\nroot - Down - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up, <cgaf> - Leaf class=0 belief=0.961756 counts=(array([0, 1]), array([679, 27]))\nroot - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([7]))\n"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"#check iterator\n",
|
||||
"for i in list(clf):\n",
|
||||
" print(i)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": "root\nroot - Down\nroot - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([301]))\nroot - Up\nroot - Up - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([2]))\nroot - Down - Up\nroot - Down - Up - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([15]))\nroot - Up - Up\nroot - Up - Up - Down\nroot - Up - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([3]))\nroot - Down - Up - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([15]))\nroot - Up - Up - Up, <cgaf> - Leaf class=0 belief=0.967468 counts=(array([0, 1]), array([684, 23]))\nroot - Up - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\n"
|
||||
"text": "root\nroot - Down\nroot - Down - Down\nroot - Down - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([312]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, <pure> - Leaf class=1 belief=1.000000 counts=(array([1]), array([8]))\nroot - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([10]))\nroot - Down - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up, <cgaf> - Leaf class=0 belief=0.961756 counts=(array([0, 1]), array([679, 27]))\nroot - Up - Down - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([7]))\n"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"#check iterator again\n",
|
||||
"for i in clf:\n",
|
||||
" print(i)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "display_data",
|
||||
"data": {
|
||||
"text/plain": "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …",
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"version_major": 2,
|
||||
"version_minor": 0,
|
||||
"model_id": "0025f832c1734afc944021e5990c2d11"
|
||||
}
|
||||
},
|
||||
"metadata": {}
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"%matplotlib widget\n",
|
||||
"from mpl_toolkits.mplot3d import Axes3D\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from matplotlib import cm\n",
|
||||
"from matplotlib.ticker import LinearLocator, FormatStrFormatter\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"fig = plt.figure()\n",
|
||||
"ax = fig.gca(projection='3d')\n",
|
||||
"\n",
|
||||
"scale = 8\n",
|
||||
"# Make data.\n",
|
||||
"X = np.arange(-scale, scale, 0.25)\n",
|
||||
"Y = np.arange(-scale, scale, 0.25)\n",
|
||||
"X, Y = np.meshgrid(X, Y)\n",
|
||||
"Z = X**2 + Y**2\n",
|
||||
"\n",
|
||||
"# Plot the surface.\n",
|
||||
"surf = ax.plot_surface(X, Y, Z, cmap=cm.coolwarm,\n",
|
||||
" linewidth=0, antialiased=False)\n",
|
||||
"\n",
|
||||
"# Customize the z axis.\n",
|
||||
"ax.set_zlim(0, 100)\n",
|
||||
"ax.zaxis.set_major_locator(LinearLocator(10))\n",
|
||||
"ax.zaxis.set_major_formatter(FormatStrFormatter('%.02f'))\n",
|
||||
"\n",
|
||||
"# rotate the axes and update\n",
|
||||
"#for angle in range(0, 360):\n",
|
||||
"# ax.view_init(30, 40)\n",
|
||||
"\n",
|
||||
"# Add a color bar which maps values to colors.\n",
|
||||
"fig.colorbar(surf, shrink=0.5, aspect=5)\n",
|
||||
"\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
394
test_graphs.ipynb
Normal file
394
test_graphs.ipynb
Normal file
File diff suppressed because one or more lines are too long
@@ -15,7 +15,7 @@ class Snode_test(unittest.TestCase):
|
||||
self._clf = Stree(random_state=self._random_state,
|
||||
use_predictions=True)
|
||||
self._clf.fit(*self._get_Xy())
|
||||
super(Snode_test, self).__init__(*args, **kwargs)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
|
@@ -16,7 +16,7 @@ class Stree_test(unittest.TestCase):
|
||||
self._clf = Stree(random_state=self._random_state,
|
||||
use_predictions=False)
|
||||
self._clf.fit(*self._get_Xy())
|
||||
super(Stree_test, self).__init__(*args, **kwargs)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
@@ -217,6 +217,23 @@ class Stree_test(unittest.TestCase):
|
||||
#
|
||||
self.assertListEqual(yp_line.tolist(), yp_once.tolist())
|
||||
|
||||
def test_iterator(self):
|
||||
"""Check preorder iterator
|
||||
"""
|
||||
expected = [
|
||||
'root',
|
||||
'root - Down',
|
||||
'root - Down - Down, <cgaf> - Leaf class=1 belief=0.975989 counts=(array([0, 1]), array([ 17, 691]))',
|
||||
'root - Down - Up',
|
||||
'root - Down - Up - Down, <cgaf> - Leaf class=1 belief=0.750000 counts=(array([0, 1]), array([1, 3]))',
|
||||
'root - Down - Up - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([7]))',
|
||||
'root - Up, <cgaf> - Leaf class=0 belief=0.928297 counts=(array([0, 1]), array([725, 56]))',
|
||||
]
|
||||
computed = []
|
||||
for node in self._clf:
|
||||
computed.append(str(node))
|
||||
self.assertListEqual(expected, computed)
|
||||
|
||||
|
||||
|
||||
|
||||
|
@@ -4,14 +4,13 @@ __copyright__ = "Copyright 2020, Ricardo Montañana Gómez"
|
||||
__license__ = "MIT"
|
||||
__version__ = "0.9"
|
||||
Inorder iterator for the binary tree of Snodes
|
||||
Uses LinearSVC
|
||||
'''
|
||||
|
||||
from trees.Snode import Snode
|
||||
|
||||
|
||||
class Siterator:
|
||||
"""Inorder iterator
|
||||
"""Stree preorder iterator
|
||||
"""
|
||||
|
||||
def __init__(self, tree: Snode):
|
||||
@@ -22,13 +21,13 @@ class Siterator:
|
||||
return self
|
||||
|
||||
def _push(self, node: Snode):
|
||||
while (node is not None):
|
||||
self._stack.insert(0, node)
|
||||
node = node.get_down()
|
||||
if node is not None:
|
||||
self._stack.append(node)
|
||||
|
||||
def __next__(self) -> Snode:
|
||||
if len(self._stack) == 0:
|
||||
raise StopIteration()
|
||||
node = self._stack.pop()
|
||||
self._push(node.get_up())
|
||||
self._push(node.get_down())
|
||||
return node
|
||||
|
@@ -11,7 +11,6 @@ import os
|
||||
import numpy as np
|
||||
from sklearn.svm import LinearSVC
|
||||
|
||||
|
||||
class Snode:
|
||||
def __init__(self, clf: LinearSVC, X: np.ndarray, y: np.ndarray, title: str):
|
||||
self._clf = clf
|
||||
@@ -26,6 +25,10 @@ class Snode:
|
||||
self._up = None
|
||||
self._class = None
|
||||
|
||||
@classmethod
|
||||
def copy(cls, node: 'Snode') -> 'Snode':
|
||||
return cls(node._clf, node._X, node._y, node._title)
|
||||
|
||||
def set_down(self, son):
|
||||
self._down = son
|
||||
|
||||
@@ -45,9 +48,6 @@ class Snode:
|
||||
"""Compute the class of the predictor and its belief based on the subdataset of the node
|
||||
only if it is a leaf
|
||||
"""
|
||||
# Clean memory
|
||||
#self._X = None
|
||||
#self._y = None
|
||||
if not self.is_leaf():
|
||||
return
|
||||
classes, card = np.unique(self._y, return_counts=True)
|
||||
@@ -67,4 +67,4 @@ class Snode:
|
||||
if self.is_leaf():
|
||||
return f"{self._title} - Leaf class={self._class} belief={self._belief:.6f} counts={np.unique(self._y, return_counts=True)}"
|
||||
else:
|
||||
return f"{self._title}"
|
||||
return f"{self._title}"
|
49
trees/Snode_graph.py
Normal file
49
trees/Snode_graph.py
Normal file
@@ -0,0 +1,49 @@
|
||||
'''
|
||||
__author__ = "Ricardo Montañana Gómez"
|
||||
__copyright__ = "Copyright 2020, Ricardo Montañana Gómez"
|
||||
__license__ = "MIT"
|
||||
__version__ = "0.9"
|
||||
Plot 3D views of nodes in Stree
|
||||
'''
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from mpl_toolkits.mplot3d import Axes3D
|
||||
from trees.Snode import Snode
|
||||
from trees.Stree import Stree
|
||||
|
||||
|
||||
class Snode_graph(Snode):
|
||||
|
||||
def __init__(self, node: Stree):
|
||||
self._plot_size = (8, 8)
|
||||
n = Snode.copy(node)
|
||||
super().__init__(n._clf, n._X, n._y, n._title)
|
||||
|
||||
def set_plot_size(self, size):
|
||||
self._plot_size = size
|
||||
|
||||
def plot_hyperplane(self):
|
||||
# get the splitting hyperplane
|
||||
def hyperplane(x, y): return (-self._interceptor - self._vector[0][0] * x
|
||||
- self._vector[0][1] * y) / self._vector[0][2]
|
||||
fig = plt.figure(figsize=self._plot_size)
|
||||
ax = fig.add_subplot(1, 1, 1, projection='3d')
|
||||
tmpx = np.linspace(self._X[:, 0].min(), self._X[:, 0].max())
|
||||
tmpy = np.linspace(self._X[:, 1].min(), self._X[:, 1].max())
|
||||
xx, yy = np.meshgrid(tmpx, tmpy)
|
||||
ax.plot_surface(xx, yy, hyperplane(xx, yy), alpha=.5, antialiased=True,
|
||||
rstride=1, cstride=1, cmap='seismic')
|
||||
plt.title(self._title)
|
||||
self.plot_distribution(ax)
|
||||
return ax
|
||||
|
||||
def plot_distribution(self, ax: Axes3D = None):
|
||||
if ax is None:
|
||||
fig = plt.figure(figsize=self._plot_size)
|
||||
ax = fig.add_subplot(1, 1, 1, projection='3d')
|
||||
ax.scatter(self._X[:, 0], self._X[:, 1], self._X[:, 2], c=self._y)
|
||||
ax.set_xlabel('X0')
|
||||
ax.set_ylabel('X1')
|
||||
ax.set_zlabel('X2')
|
||||
plt.show()
|
@@ -13,7 +13,6 @@ import numpy as np
|
||||
from sklearn.base import BaseEstimator, ClassifierMixin
|
||||
from sklearn.svm import LinearSVC
|
||||
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
|
||||
|
||||
from trees.Snode import Snode
|
||||
from trees.Siterator import Siterator
|
||||
|
||||
@@ -22,7 +21,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
"""
|
||||
"""
|
||||
|
||||
def __init__(self, C=1.0, max_iter: int = 1000, random_state: int = 0, use_predictions: bool = False):
|
||||
def __init__(self, C: float=1.0, max_iter: int = 1000, random_state: int = 0, use_predictions: bool = False):
|
||||
self._max_iter = max_iter
|
||||
self._C = C
|
||||
self._random_state = random_state
|
||||
|
51
trees/Stree_grapher.py
Normal file
51
trees/Stree_grapher.py
Normal file
@@ -0,0 +1,51 @@
|
||||
'''
|
||||
__author__ = "Ricardo Montañana Gómez"
|
||||
__copyright__ = "Copyright 2020, Ricardo Montañana Gómez"
|
||||
__license__ = "MIT"
|
||||
__version__ = "0.9"
|
||||
Plot 3D views of nodes in Stree
|
||||
'''
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
from sklearn.decomposition import PCA
|
||||
import trees
|
||||
import matplotlib.pyplot as plt
|
||||
from trees.Snode import Snode
|
||||
from trees.Snode_graph import Snode_graph
|
||||
from trees.Stree import Stree
|
||||
from trees.Siterator import Siterator
|
||||
|
||||
class Stree_grapher(Stree):
|
||||
def __init__(self, params: dict):
|
||||
self._plot_size = (8, 8)
|
||||
self._tree_gr = None
|
||||
# make Snode store X's
|
||||
os.environ['TESTING'] = '1'
|
||||
super().__init__(**params)
|
||||
|
||||
def __del__(self):
|
||||
try:
|
||||
os.environ.pop('TESTING')
|
||||
except:
|
||||
pass
|
||||
plt.close('all')
|
||||
|
||||
def _copy_tree(self, node: Snode) -> Snode_graph:
|
||||
mirror = Snode_graph(node)
|
||||
if node.get_down() is not None:
|
||||
mirror.set_down(self._copy_tree(node.get_down()))
|
||||
if node.get_up() is not None:
|
||||
mirror.set_up(self._copy_tree(node.get_up()))
|
||||
return mirror
|
||||
|
||||
def fit(self, X: np.array, y: np.array) -> Stree:
|
||||
if X.shape[1] != 3:
|
||||
pca = PCA(n_components=3)
|
||||
X = pca.fit_transform(X)
|
||||
res = super().fit(X, y)
|
||||
self._tree_gr = self._copy_tree(self._tree)
|
||||
return res
|
||||
|
||||
def __iter__(self):
|
||||
return Siterator(self._tree_gr)
|
Reference in New Issue
Block a user