Files
stree/trees/Stree_grapher.py

102 lines
2.9 KiB
Python

'''
__author__ = "Ricardo Montañana Gómez"
__copyright__ = "Copyright 2020, Ricardo Montañana Gómez"
__license__ = "MIT"
__version__ = "0.9"
Plot 3D views of nodes in Stree
'''
import os
import numpy as np
from sklearn.decomposition import PCA
import trees
import matplotlib.pyplot as plt
from trees.Snode import Snode
from trees.Snode_graph import Snode_graph
from trees.Stree import Stree
from trees.Siterator import Siterator
class Stree_grapher(Stree):
    """Build 3D graphs of the nodes of an Stree fitted on any dataset.

    When the dataset does not have exactly 3 features, PCA is used to
    project it onto 3 components before fitting.
    """

    def __init__(self, params: dict):
        """Create the grapher.

        :param params: keyword arguments forwarded to the Stree constructor
        :type params: dict
        """
        self._plot_size = (8, 8)
        self._tree_gr = None
        # make Snode store X's
        os.environ['TESTING'] = '1'
        self._fitted = False
        self._pca = None
        super().__init__(**params)

    def __del__(self):
        """Remove the TESTING flag set in __init__ and close all figures."""
        # The variable may already be gone, so give pop() a default instead
        # of hiding every possible error behind a bare except.
        # NOTE(review): the flag is process-wide, so destroying one grapher
        # clears it for any other live instance as well - confirm intended.
        os.environ.pop('TESTING', None)
        plt.close('all')

    def _copy_tree(self, node: Snode) -> Snode_graph:
        """Recursively mirror the subtree rooted at node with Snode_graph

        :param node: root of the subtree to copy
        :type node: Snode
        :return: root of the mirrored subtree
        :rtype: Snode_graph
        """
        mirror = Snode_graph(node)
        # clone node
        mirror._class = node._class
        mirror._belief = node._belief
        if node.get_down() is not None:
            mirror.set_down(self._copy_tree(node.get_down()))
        if node.get_up() is not None:
            mirror.set_up(self._copy_tree(node.get_up()))
        return mirror

    def fit(self, X: np.ndarray, y: np.ndarray) -> Stree:
        """Fit the Stree and copy the tree in a Snode_graph tree

        :param X: Dataset
        :type X: np.ndarray
        :param y: Labels
        :type y: np.ndarray
        :return: Stree model
        :rtype: Stree
        """
        if X.shape[1] != 3:
            self._pca = PCA(n_components=3)
            X = self._pca.fit_transform(X)
        else:
            # Forget the projection of a previous fit so score() can never
            # apply a PCA that does not belong to the current model.
            self._pca = None
        res = super().fit(X, y)
        self._tree_gr = self._copy_tree(self._tree)
        self._fitted = True
        return res

    def score(self, X: np.ndarray, y: np.ndarray) -> float:
        """Score the fitted model, projecting X like fit() did if needed

        :param X: Dataset
        :type X: np.ndarray
        :param y: Labels
        :type y: np.ndarray
        :raises Exception: if the grapher has not been fitted
        :raises ValueError: if X does not have 3 features and the model was
            fitted without a PCA projection
        :return: score computed by Stree.score
        :rtype: float
        """
        self._check_fitted()
        if X.shape[1] != 3:
            if self._pca is None:
                raise ValueError(
                    'The grapher was fitted with 3 features but X has '
                    f'{X.shape[1]}'
                )
            X = self._pca.transform(X)
        return super().score(X, y)

    def _check_fitted(self):
        """Raise an exception if fit() has not completed yet

        :raises Exception: if the grapher has not been fitted
        """
        if not self._fitted:
            raise Exception('Have to fit the grapher first!')

    def save_all(self, save_folder: str = './', save_prefix: str = ''):
        """Save all the node plots in png format, each with a sequence number

        :param save_folder: folder where the plots are saved, defaults to './'
        :type save_folder: str, optional
        :param save_prefix: prefix handed to every node's save_hyperplane,
            defaults to ''
        :type save_prefix: str, optional
        :raises Exception: if the grapher has not been fitted
        """
        self._check_fitted()
        # sequence numbers start at 1
        for seq, node in enumerate(self, start=1):
            node.save_hyperplane(save_folder=save_folder,
                                 save_prefix=save_prefix, save_seq=seq)

    def plot_all(self):
        """Plots all the nodes

        :raises Exception: if the grapher has not been fitted
        """
        self._check_fitted()
        for node in self:
            node.plot_hyperplane()

    def __iter__(self):
        """Return a Siterator over the Snode_graph copy of the tree"""
        return Siterator(self._tree_gr)