mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-17 08:26:00 +00:00
Show sample_weight use in test2 notebook
Update revision to RC4 Lint Stree grapher
This commit is contained in:
@@ -15,6 +15,7 @@ from mpl_toolkits.mplot3d import Axes3D
|
||||
|
||||
from .Strees import Stree, Snode, Siterator
|
||||
|
||||
|
||||
class Snode_graph(Snode):
|
||||
|
||||
def __init__(self, node: Stree):
|
||||
@@ -45,7 +46,8 @@ class Snode_graph(Snode):
|
||||
ax.set_ylim(self._ylimits)
|
||||
ax.set_zlim(self._zlimits)
|
||||
|
||||
def save_hyperplane(self, save_folder: str = './', save_prefix: str = '', save_seq: int = 1):
|
||||
def save_hyperplane(self, save_folder: str = './', save_prefix: str = '',
|
||||
save_seq: int = 1):
|
||||
_, fig = self.plot_hyperplane()
|
||||
name = f"{save_folder}{save_prefix}STnode{save_seq}.png"
|
||||
fig.savefig(name, bbox_inches='tight')
|
||||
@@ -53,9 +55,8 @@ class Snode_graph(Snode):
|
||||
|
||||
def _get_cmap(self):
|
||||
cmap = 'jet'
|
||||
if self._is_pure():
|
||||
if self._class == 1:
|
||||
cmap = 'jet_r'
|
||||
if self._is_pure() and self._class == 1:
|
||||
cmap = 'jet_r'
|
||||
return cmap
|
||||
|
||||
def _graph_title(self):
|
||||
@@ -66,16 +67,20 @@ class Snode_graph(Snode):
|
||||
fig = plt.figure(figsize=self._plot_size)
|
||||
ax = fig.add_subplot(1, 1, 1, projection='3d')
|
||||
if not self._is_pure():
|
||||
# Can't plot hyperplane of leaves with one label because it hasn't classiffier
|
||||
# Can't plot hyperplane of leaves with one label because it hasn't
|
||||
# classiffier
|
||||
# get the splitting hyperplane
|
||||
def hyperplane(x, y): return (-self._interceptor - self._vector[0][0] * x
|
||||
- self._vector[0][1] * y) / self._vector[0][2]
|
||||
def hyperplane(x, y): return (-self._interceptor
|
||||
- self._vector[0][0] * x
|
||||
- self._vector[0][1] * y) \
|
||||
/ self._vector[0][2]
|
||||
|
||||
tmpx = np.linspace(self._X[:, 0].min(), self._X[:, 0].max())
|
||||
tmpy = np.linspace(self._X[:, 1].min(), self._X[:, 1].max())
|
||||
xx, yy = np.meshgrid(tmpx, tmpy)
|
||||
ax.plot_surface(xx, yy, hyperplane(xx, yy), alpha=.5, antialiased=True,
|
||||
rstride=1, cstride=1, cmap='seismic')
|
||||
ax.plot_surface(xx, yy, hyperplane(xx, yy), alpha=.5,
|
||||
antialiased=True, rstride=1, cstride=1,
|
||||
cmap='seismic')
|
||||
self._set_graphics_axis(ax)
|
||||
if plot_distribution:
|
||||
self.plot_distribution(ax)
|
||||
@@ -97,6 +102,7 @@ class Snode_graph(Snode):
|
||||
ax.set_zlabel('X2')
|
||||
plt.show()
|
||||
|
||||
|
||||
class Stree_grapher(Stree):
|
||||
"""Build 3d graphs of any dataset, if it's more than 3 features PCA shall
|
||||
make its magic
|
||||
@@ -114,7 +120,7 @@ class Stree_grapher(Stree):
|
||||
def __del__(self):
|
||||
try:
|
||||
os.environ.pop('TESTING')
|
||||
except:
|
||||
except KeyError:
|
||||
pass
|
||||
plt.close('all')
|
||||
|
||||
@@ -181,4 +187,3 @@ class Stree_grapher(Stree):
|
||||
|
||||
def __iter__(self):
|
||||
return Siterator(self._tree_gr)
|
||||
|
||||
|
Reference in New Issue
Block a user