diff --git a/stree/Strees.py b/stree/Strees.py index 32c9ff2..8fa13d2 100644 --- a/stree/Strees.py +++ b/stree/Strees.py @@ -155,6 +155,10 @@ class Siterator: self._stack = [] self._push(tree) + def __iter__(self): + # To complete the iterator interface + return self + def _push(self, node: Snode): if node is not None: self._stack.append(node) diff --git a/stree/tests/Stree_test.py b/stree/tests/Stree_test.py index 803ad3c..42872db 100644 --- a/stree/tests/Stree_test.py +++ b/stree/tests/Stree_test.py @@ -161,7 +161,7 @@ class Stree_test(unittest.TestCase): random_state=self._random_state, ) clf.fit(*load_dataset(self._random_state)) - for node in clf: + for node in iter(clf): computed.append(str(node)) expected_string += str(node) + "\n" self.assertListEqual(expected, computed) @@ -242,28 +242,28 @@ class Stree_test(unittest.TestCase): } outcomes = { "Synt": { - "max_samples liblinear": 0.9606666666666667, - "max_samples linear": 0.9486666666666667, - "max_samples rbf": 0.978, - "max_samples poly": 0.96, - "max_samples sigmoid": 0.908, - "impurity liblinear": 0.9606666666666667, - "impurity linear": 0.9486666666666667, - "impurity rbf": 0.978, - "impurity poly": 0.96, - "impurity sigmoid": 0.908, + "max_samples liblinear": 0.9493333333333334, + "max_samples linear": 0.9426666666666667, + "max_samples rbf": 0.9606666666666667, + "max_samples poly": 0.9373333333333334, + "max_samples sigmoid": 0.824, + "impurity liblinear": 0.9493333333333334, + "impurity linear": 0.9426666666666667, + "impurity rbf": 0.9606666666666667, + "impurity poly": 0.9373333333333334, + "impurity sigmoid": 0.824, }, "Iris": { - "max_samples liblinear": 1.0, + "max_samples liblinear": 0.9550561797752809, "max_samples linear": 1.0, - "max_samples rbf": 0.7808988764044944, - "max_samples poly": 0.8202247191011236, - "max_samples sigmoid": 0.7528089887640449, - "impurity liblinear": 1.0, + "max_samples rbf": 0.6685393258426966, + "max_samples poly": 0.6853932584269663, + "max_samples sigmoid": 0.6404494382022472, + "impurity liblinear": 0.9550561797752809, "impurity linear": 1.0, - "impurity rbf": 0.7808988764044944, - "impurity poly": 0.8202247191011236, - "impurity sigmoid": 0.7528089887640449, + "impurity rbf": 0.6685393258426966, + "impurity poly": 0.6853932584269663, + "impurity sigmoid": 0.6404494382022472, }, } @@ -272,8 +272,7 @@ class Stree_test(unittest.TestCase): for criteria in ["max_samples", "impurity"]: for kernel in self._kernels: clf = Stree( - C=55, - max_iter=1e5, + max_iter=1e4, multiclass_strategy="ovr" if kernel == "liblinear" else "ovo", @@ -286,6 +285,7 @@ class Stree_test(unittest.TestCase): self.assertAlmostEqual( outcome, clf.score(px, py), + 5, f"{name} - {criteria} - {kernel}", )