mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-18 08:56:00 +00:00
integrate iterator in Stree
This commit is contained in:
@@ -1,22 +1,34 @@
|
||||
'''
|
||||
__author__ = "Ricardo Montañana Gómez"
|
||||
__copyright__ = "Copyright 2020, Ricardo Montañana Gómez"
|
||||
__license__ = "MIT"
|
||||
__version__ = "0.9"
|
||||
Inorder iterator for the binary tree of Snodes
|
||||
Uses LinearSVC
|
||||
'''
|
||||
|
||||
from trees.Snode import Snode
|
||||
|
||||
|
||||
class Siterator:
|
||||
"""Implements an inorder iterator
|
||||
"""Inorder iterator
|
||||
"""
|
||||
|
||||
def __init__(self, tree: Snode):
|
||||
self._stack = []
|
||||
self._push(tree)
|
||||
|
||||
def hasNext(self) -> bool:
|
||||
return len(self._stack) > 0
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def _push(self, node: Snode):
|
||||
while (node is not None):
|
||||
self._stack.insert(0, node)
|
||||
node = node.get_down()
|
||||
|
||||
def next(self) -> Snode:
|
||||
def __next__(self) -> Snode:
|
||||
if len(self._stack) == 0:
|
||||
raise StopIteration()
|
||||
node = self._stack.pop()
|
||||
self._push(node.get_up())
|
||||
return node
|
||||
|
@@ -65,6 +65,6 @@ class Snode:
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.is_leaf():
|
||||
return f"{self._title} - Leaf class={self._class} belief={self._belief:.6f} counts={np.unique(self._y, return_counts=True)}\n"
|
||||
return f"{self._title} - Leaf class={self._class} belief={self._belief:.6f} counts={np.unique(self._y, return_counts=True)}"
|
||||
else:
|
||||
return f"{self._title}\n"
|
||||
return f"{self._title}"
|
||||
|
@@ -1,4 +1,3 @@
|
||||
# This Python file uses the following encoding: utf-8
|
||||
'''
|
||||
__author__ = "Ricardo Montañana Gómez"
|
||||
__copyright__ = "Copyright 2020, Ricardo Montañana Gómez"
|
||||
@@ -16,13 +15,14 @@ from sklearn.svm import LinearSVC
|
||||
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
|
||||
|
||||
from trees.Snode import Snode
|
||||
from trees.Siterator import Siterator
|
||||
|
||||
|
||||
class Stree(BaseEstimator, ClassifierMixin):
|
||||
"""
|
||||
"""
|
||||
|
||||
def __init__(self, C=1.0, max_iter: int=1000, random_state: int=0, use_predictions: bool=False):
|
||||
def __init__(self, C=1.0, max_iter: int = 1000, random_state: int = 0, use_predictions: bool = False):
|
||||
self._max_iter = max_iter
|
||||
self._C = C
|
||||
self._random_state = random_state
|
||||
@@ -184,28 +184,15 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
right = (yp == y).astype(int)
|
||||
return np.sum(right) / len(y)
|
||||
|
||||
def __print_tree(self, tree: Snode, only_leaves=False) -> str:
|
||||
if not only_leaves:
|
||||
output = str(tree)
|
||||
else:
|
||||
output = ''
|
||||
if tree.is_leaf():
|
||||
if only_leaves:
|
||||
output = str(tree)
|
||||
return output
|
||||
output += self.__print_tree(tree.get_down(), only_leaves)
|
||||
output += self.__print_tree(tree.get_up(), only_leaves)
|
||||
def __iter__(self):
|
||||
return Siterator(self._tree)
|
||||
|
||||
def __str__(self) -> str:
|
||||
output = ''
|
||||
for i in self:
|
||||
output += str(i) + '\n'
|
||||
return output
|
||||
|
||||
def show_tree(self, only_leaves=False):
|
||||
if only_leaves:
|
||||
print(self.__print_tree(self._tree, only_leaves=True))
|
||||
else:
|
||||
print(self)
|
||||
|
||||
def __str__(self):
|
||||
return self.__print_tree(self._tree)
|
||||
|
||||
def _save_datasets(self, tree: Snode, catalog: typing.TextIO, number: int):
|
||||
"""Save the dataset of the node in a csv file
|
||||
|
||||
@@ -232,4 +219,4 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
"""Save the every dataset stored in the tree to check with manual classifier
|
||||
"""
|
||||
with open(self.get_catalog_name(), 'w', encoding='utf-8') as catalog:
|
||||
self._save_datasets(self._tree, catalog, 1)
|
||||
self._save_datasets(self._tree, catalog, 1)
|
||||
|
Reference in New Issue
Block a user