mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-15 07:26:01 +00:00
96 lines
3.2 KiB
Python
96 lines
3.2 KiB
Python
'''
|
|
__author__ = "Ricardo Montañana Gómez"
|
|
__copyright__ = "Copyright 2020, Ricardo Montañana Gómez"
|
|
__license__ = "MIT"
|
|
__version__ = "0.9"
|
|
Plot 3D views of nodes in Stree
|
|
'''
|
|
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from mpl_toolkits.mplot3d import Axes3D
|
|
from trees.Snode import Snode
|
|
from trees.Stree import Stree
|
|
|
|
|
|
class Snode_graph(Snode):
|
|
|
|
def __init__(self, node: Stree):
|
|
self._plot_size = (8, 8)
|
|
self._xlimits = (None, None)
|
|
self._ylimits = (None, None)
|
|
self._zlimits = (None, None)
|
|
n = Snode.copy(node)
|
|
super().__init__(n._clf, n._X, n._y, n._title)
|
|
|
|
def set_plot_size(self, size: tuple):
|
|
self._plot_size = size
|
|
|
|
def _is_pure(self) -> bool:
|
|
"""is considered pure a leaf node with one label
|
|
"""
|
|
if self.is_leaf():
|
|
return self._belief == 1.
|
|
return False
|
|
|
|
def set_axis_limits(self, limits: tuple):
|
|
self._xlimits = limits[0]
|
|
self._ylimits = limits[1]
|
|
self._zlimits = limits[2]
|
|
|
|
def _set_graphics_axis(self, ax: Axes3D):
|
|
ax.set_xlim(self._xlimits)
|
|
ax.set_ylim(self._ylimits)
|
|
ax.set_zlim(self._zlimits)
|
|
|
|
def save_hyperplane(self, save_folder: str = './', save_prefix: str = '', save_seq: int = 1):
|
|
_, fig = self.plot_hyperplane()
|
|
name = f"{save_folder}{save_prefix}STnode{save_seq}.png"
|
|
fig.savefig(name, bbox_inches='tight')
|
|
plt.close(fig)
|
|
|
|
def _get_cmap(self):
|
|
cmap = 'jet'
|
|
if self._is_pure():
|
|
if self._class == 1:
|
|
cmap = 'jet_r'
|
|
return cmap
|
|
|
|
def _graph_title(self):
|
|
n_class, card = np.unique(self._y, return_counts=True)
|
|
return f"{self._title} {n_class} {card}"
|
|
|
|
def plot_hyperplane(self, plot_distribution: bool = True):
|
|
fig = plt.figure(figsize=self._plot_size)
|
|
ax = fig.add_subplot(1, 1, 1, projection='3d')
|
|
if not self._is_pure():
|
|
# Can't plot hyperplane of leaves with one label because it hasn't classiffier
|
|
# get the splitting hyperplane
|
|
def hyperplane(x, y): return (-self._interceptor - self._vector[0][0] * x
|
|
- self._vector[0][1] * y) / self._vector[0][2]
|
|
tmpx = np.linspace(self._X[:, 0].min(), self._X[:, 0].max())
|
|
tmpy = np.linspace(self._X[:, 1].min(), self._X[:, 1].max())
|
|
xx, yy = np.meshgrid(tmpx, tmpy)
|
|
ax.plot_surface(xx, yy, hyperplane(xx, yy), alpha=.5, antialiased=True,
|
|
rstride=1, cstride=1, cmap='seismic')
|
|
self._set_graphics_axis(ax)
|
|
if plot_distribution:
|
|
self.plot_distribution(ax)
|
|
else:
|
|
plt.title(self._graph_title())
|
|
plt.show()
|
|
return ax, fig
|
|
|
|
def plot_distribution(self, ax: Axes3D = None):
|
|
if ax is None:
|
|
fig = plt.figure(figsize=self._plot_size)
|
|
ax = fig.add_subplot(1, 1, 1, projection='3d')
|
|
plt.title(self._graph_title())
|
|
cmap = self._get_cmap()
|
|
ax.scatter(self._X[:, 0], self._X[:, 1],
|
|
self._X[:, 2], c=self._y, cmap=cmap)
|
|
ax.set_xlabel('X0')
|
|
ax.set_ylabel('X1')
|
|
ax.set_zlabel('X2')
|
|
plt.show()
|