mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-15 15:36:00 +00:00
first approx to grapher
This commit is contained in:
51
trees/Stree_grapher.py
Normal file
51
trees/Stree_grapher.py
Normal file
@@ -0,0 +1,51 @@
|
||||
'''
|
||||
__author__ = "Ricardo Montañana Gómez"
|
||||
__copyright__ = "Copyright 2020, Ricardo Montañana Gómez"
|
||||
__license__ = "MIT"
|
||||
__version__ = "0.9"
|
||||
Plot 3D views of nodes in Stree
|
||||
'''
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
from sklearn.decomposition import PCA
|
||||
import trees
|
||||
import matplotlib.pyplot as plt
|
||||
from trees.Snode import Snode
|
||||
from trees.Snode_graph import Snode_graph
|
||||
from trees.Stree import Stree
|
||||
from trees.Siterator import Siterator
|
||||
|
||||
class Stree_grapher(Stree):
|
||||
def __init__(self, params: dict):
|
||||
self._plot_size = (8, 8)
|
||||
self._tree_gr = None
|
||||
# make Snode store X's
|
||||
os.environ['TESTING'] = '1'
|
||||
super().__init__(**params)
|
||||
|
||||
def __del__(self):
|
||||
try:
|
||||
os.environ.pop('TESTING')
|
||||
except:
|
||||
pass
|
||||
plt.close('all')
|
||||
|
||||
def _copy_tree(self, node: Snode) -> Snode_graph:
|
||||
mirror = Snode_graph(node)
|
||||
if node.get_down() is not None:
|
||||
mirror.set_down(self._copy_tree(node.get_down()))
|
||||
if node.get_up() is not None:
|
||||
mirror.set_up(self._copy_tree(node.get_up()))
|
||||
return mirror
|
||||
|
||||
def fit(self, X: np.array, y: np.array) -> Stree:
|
||||
if X.shape[1] != 3:
|
||||
pca = PCA(n_components=3)
|
||||
X = pca.fit_transform(X)
|
||||
res = super().fit(X, y)
|
||||
self._tree_gr = self._copy_tree(self._tree)
|
||||
return res
|
||||
|
||||
def __iter__(self):
|
||||
return Siterator(self._tree_gr)
|
Reference in New Issue
Block a user