Show sample_weight use in test2 notebook

Update revision to RC4
Lint Stree grapher
This commit is contained in:
2020-05-30 23:59:40 +02:00
parent 5e5fea9c6a
commit b4816b2995
3 changed files with 55 additions and 58 deletions

View File

@@ -15,6 +15,7 @@ from mpl_toolkits.mplot3d import Axes3D
from .Strees import Stree, Snode, Siterator
class Snode_graph(Snode):
def __init__(self, node: Stree):
@@ -45,7 +46,8 @@ class Snode_graph(Snode):
ax.set_ylim(self._ylimits)
ax.set_zlim(self._zlimits)
def save_hyperplane(self, save_folder: str = './', save_prefix: str = '', save_seq: int = 1):
def save_hyperplane(self, save_folder: str = './', save_prefix: str = '',
save_seq: int = 1):
_, fig = self.plot_hyperplane()
name = f"{save_folder}{save_prefix}STnode{save_seq}.png"
fig.savefig(name, bbox_inches='tight')
@@ -53,9 +55,8 @@ class Snode_graph(Snode):
def _get_cmap(self):
cmap = 'jet'
if self._is_pure():
if self._class == 1:
cmap = 'jet_r'
if self._is_pure() and self._class == 1:
cmap = 'jet_r'
return cmap
def _graph_title(self):
@@ -66,16 +67,20 @@ class Snode_graph(Snode):
fig = plt.figure(figsize=self._plot_size)
ax = fig.add_subplot(1, 1, 1, projection='3d')
if not self._is_pure():
# Can't plot hyperplane of leaves with one label because it hasn't classiffier
# Can't plot hyperplane of leaves with one label because it hasn't
# classiffier
# get the splitting hyperplane
def hyperplane(x, y): return (-self._interceptor - self._vector[0][0] * x
- self._vector[0][1] * y) / self._vector[0][2]
def hyperplane(x, y): return (-self._interceptor
- self._vector[0][0] * x
- self._vector[0][1] * y) \
/ self._vector[0][2]
tmpx = np.linspace(self._X[:, 0].min(), self._X[:, 0].max())
tmpy = np.linspace(self._X[:, 1].min(), self._X[:, 1].max())
xx, yy = np.meshgrid(tmpx, tmpy)
ax.plot_surface(xx, yy, hyperplane(xx, yy), alpha=.5, antialiased=True,
rstride=1, cstride=1, cmap='seismic')
ax.plot_surface(xx, yy, hyperplane(xx, yy), alpha=.5,
antialiased=True, rstride=1, cstride=1,
cmap='seismic')
self._set_graphics_axis(ax)
if plot_distribution:
self.plot_distribution(ax)
@@ -97,6 +102,7 @@ class Snode_graph(Snode):
ax.set_zlabel('X2')
plt.show()
class Stree_grapher(Stree):
"""Build 3d graphs of any dataset, if it's more than 3 features PCA shall
make its magic
@@ -114,7 +120,7 @@ class Stree_grapher(Stree):
def __del__(self):
try:
os.environ.pop('TESTING')
except:
except KeyError:
pass
plt.close('all')
@@ -181,4 +187,3 @@ class Stree_grapher(Stree):
def __iter__(self):
return Siterator(self._tree_gr)