integrate iterator in Stree

This commit is contained in:
2020-05-19 18:19:23 +02:00
parent 95a6901f47
commit 6ebd0f9be3
5 changed files with 53 additions and 74 deletions

View File

@@ -1,22 +1,34 @@
'''
__author__ = "Ricardo Montañana Gómez"
__copyright__ = "Copyright 2020, Ricardo Montañana Gómez"
__license__ = "MIT"
__version__ = "0.9"
Inorder iterator for the binary tree of Snodes
Uses LinearSVC
'''
from trees.Snode import Snode
class Siterator:
"""Implements an inorder iterator
"""Inorder iterator
"""
def __init__(self, tree: Snode):
self._stack = []
self._push(tree)
def hasNext(self) -> bool:
return len(self._stack) > 0
def __iter__(self):
return self
def _push(self, node: Snode):
while (node is not None):
self._stack.insert(0, node)
node = node.get_down()
def next(self) -> Snode:
def __next__(self) -> Snode:
if len(self._stack) == 0:
raise StopIteration()
node = self._stack.pop()
self._push(node.get_up())
return node

View File

@@ -65,6 +65,6 @@ class Snode:
def __str__(self) -> str:
if self.is_leaf():
return f"{self._title} - Leaf class={self._class} belief={self._belief:.6f} counts={np.unique(self._y, return_counts=True)}\n"
return f"{self._title} - Leaf class={self._class} belief={self._belief:.6f} counts={np.unique(self._y, return_counts=True)}"
else:
return f"{self._title}\n"
return f"{self._title}"

View File

@@ -1,4 +1,3 @@
# This Python file uses the following encoding: utf-8
'''
__author__ = "Ricardo Montañana Gómez"
__copyright__ = "Copyright 2020, Ricardo Montañana Gómez"
@@ -16,13 +15,14 @@ from sklearn.svm import LinearSVC
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
from trees.Snode import Snode
from trees.Siterator import Siterator
class Stree(BaseEstimator, ClassifierMixin):
"""
"""
def __init__(self, C=1.0, max_iter: int=1000, random_state: int=0, use_predictions: bool=False):
def __init__(self, C=1.0, max_iter: int = 1000, random_state: int = 0, use_predictions: bool = False):
self._max_iter = max_iter
self._C = C
self._random_state = random_state
@@ -184,28 +184,15 @@ class Stree(BaseEstimator, ClassifierMixin):
right = (yp == y).astype(int)
return np.sum(right) / len(y)
def __print_tree(self, tree: Snode, only_leaves=False) -> str:
if not only_leaves:
output = str(tree)
else:
output = ''
if tree.is_leaf():
if only_leaves:
output = str(tree)
return output
output += self.__print_tree(tree.get_down(), only_leaves)
output += self.__print_tree(tree.get_up(), only_leaves)
def __iter__(self):
return Siterator(self._tree)
def __str__(self) -> str:
output = ''
for i in self:
output += str(i) + '\n'
return output
def show_tree(self, only_leaves=False):
if only_leaves:
print(self.__print_tree(self._tree, only_leaves=True))
else:
print(self)
def __str__(self):
return self.__print_tree(self._tree)
def _save_datasets(self, tree: Snode, catalog: typing.TextIO, number: int):
"""Save the dataset of the node in a csv file
@@ -232,4 +219,4 @@ class Stree(BaseEstimator, ClassifierMixin):
"""Save the every dataset stored in the tree to check with manual classifier
"""
with open(self.get_catalog_name(), 'w', encoding='utf-8') as catalog:
self._save_datasets(self._tree, catalog, 1)
self._save_datasets(self._tree, catalog, 1)