mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-15 07:26:01 +00:00
102 lines
2.9 KiB
Python
102 lines
2.9 KiB
Python
'''
|
|
__author__ = "Ricardo Montañana Gómez"
|
|
__copyright__ = "Copyright 2020, Ricardo Montañana Gómez"
|
|
__license__ = "MIT"
|
|
__version__ = "0.9"
|
|
Plot 3D views of nodes in Stree
|
|
'''
|
|
|
|
import os
|
|
import numpy as np
|
|
from sklearn.decomposition import PCA
|
|
import trees
|
|
import matplotlib.pyplot as plt
|
|
from trees.Snode import Snode
|
|
from trees.Snode_graph import Snode_graph
|
|
from trees.Stree import Stree
|
|
from trees.Siterator import Siterator
|
|
|
|
|
|
class Stree_grapher(Stree):
    """Build 3d graphs of any dataset, if it's more than 3 features PCA shall
    make its magic

    The classifier is always trained on 3 columns: a dataset whose number of
    features is not 3 is first projected with ``PCA(n_components=3)``.
    NOTE(review): datasets with fewer than 3 features take the same PCA path;
    scikit-learn is expected to reject that request -- confirm.
    """

    def __init__(self, params: dict):
        """Create the grapher and initialise the wrapped Stree.

        :param params: keyword arguments forwarded to ``Stree.__init__``
        :type params: dict
        """
        self._plot_size = (8, 8)
        self._tree_gr = None
        # make Snode store X's (process-wide flag, removed again in __del__)
        os.environ['TESTING'] = '1'
        self._fitted = False
        self._pca = None
        super().__init__(**params)

    def __del__(self):
        """Remove the TESTING flag and close every open matplotlib figure."""
        # pop with a default instead of a bare ``except: pass``: the flag may
        # already be gone, and any unrelated error is no longer hidden
        os.environ.pop('TESTING', None)
        plt.close('all')

    def _copy_tree(self, node: Snode) -> Snode_graph:
        """Recursively mirror the tree rooted at ``node`` with Snode_graph nodes.

        :param node: root of the (sub)tree to copy
        :type node: Snode
        :return: root of the mirrored (sub)tree
        :rtype: Snode_graph
        """
        mirror = Snode_graph(node)
        # clone node
        mirror._class = node._class
        mirror._belief = node._belief
        if node.get_down() is not None:
            mirror.set_down(self._copy_tree(node.get_down()))
        if node.get_up() is not None:
            mirror.set_up(self._copy_tree(node.get_up()))
        return mirror

    def fit(self, X: np.array, y: np.array) -> Stree:
        """Fit the Stree and copy the tree in a Snode_graph tree

        :param X: Dataset
        :type X: np.array
        :param y: Labels
        :type y: np.array
        :return: Stree model
        :rtype: Stree
        """
        if X.shape[1] != 3:
            self._pca = PCA(n_components=3)
            X = self._pca.fit_transform(X)
        else:
            # forget the projection of a previous fit, otherwise score()
            # could later apply a PCA this tree was never trained with
            self._pca = None
        res = super().fit(X, y)
        self._tree_gr = self._copy_tree(self._tree)
        self._fitted = True
        return res

    def score(self, X: np.array, y: np.array) -> float:
        """Score the fitted tree, projecting X first when it is not 3 columns.

        :param X: Dataset
        :type X: np.array
        :param y: Labels
        :type y: np.array
        :raises Exception: if the grapher has not been fitted
        :raises ValueError: if X does not have 3 features and no PCA was
            fitted (the model was trained directly on 3 features)
        :return: value returned by ``Stree.score``
        :rtype: float
        """
        self._check_fitted()
        if X.shape[1] != 3:
            if self._pca is None:
                raise ValueError(
                    'The grapher was fitted with 3 features, cannot score '
                    'a dataset with %d features' % X.shape[1]
                )
            X = self._pca.transform(X)
        return super().score(X, y)

    def _check_fitted(self):
        """Raise if ``fit`` has not completed successfully yet.

        :raises Exception: when the grapher is not fitted
        """
        if not self._fitted:
            raise Exception('Have to fit the grapher first!')

    def save_all(self, save_folder: str = './', save_prefix: str = ''):
        """Save all the node plots in png format, each with a sequence number

        :param save_folder: folder where the plots are saved, defaults to './'
        :type save_folder: str, optional
        :param save_prefix: value forwarded as ``save_prefix`` to every
            node's ``save_hyperplane``, defaults to ''
        :type save_prefix: str, optional
        :raises Exception: if the grapher has not been fitted
        """
        self._check_fitted()
        # sequence numbers start at 1, following iteration order
        for seq, node in enumerate(self, start=1):
            node.save_hyperplane(save_folder=save_folder,
                                 save_prefix=save_prefix, save_seq=seq)

    def plot_all(self):
        """Plots all the nodes

        :raises Exception: if the grapher has not been fitted
        """
        self._check_fitted()
        for node in self:
            node.plot_hyperplane()

    def __iter__(self):
        """Return a Siterator over the mirrored Snode_graph tree."""
        return Siterator(self._tree_gr)
|