Solve weak classifier

This commit is contained in:
2021-05-10 12:06:47 +02:00
parent 34bf539fa3
commit acf10d8d7f
2 changed files with 25 additions and 21 deletions

View File

@@ -155,6 +155,10 @@ class Siterator:
self._stack = [] self._stack = []
self._push(tree) self._push(tree)
def __iter__(self):
# To complete the iterator interface
return self
def _push(self, node: Snode): def _push(self, node: Snode):
if node is not None: if node is not None:
self._stack.append(node) self._stack.append(node)

View File

@@ -161,7 +161,7 @@ class Stree_test(unittest.TestCase):
random_state=self._random_state, random_state=self._random_state,
) )
clf.fit(*load_dataset(self._random_state)) clf.fit(*load_dataset(self._random_state))
for node in clf: for node in iter(clf):
computed.append(str(node)) computed.append(str(node))
expected_string += str(node) + "\n" expected_string += str(node) + "\n"
self.assertListEqual(expected, computed) self.assertListEqual(expected, computed)
@@ -242,28 +242,28 @@ class Stree_test(unittest.TestCase):
} }
outcomes = { outcomes = {
"Synt": { "Synt": {
"max_samples liblinear": 0.9606666666666667, "max_samples liblinear": 0.9493333333333334,
"max_samples linear": 0.9486666666666667, "max_samples linear": 0.9426666666666667,
"max_samples rbf": 0.978, "max_samples rbf": 0.9606666666666667,
"max_samples poly": 0.96, "max_samples poly": 0.9373333333333334,
"max_samples sigmoid": 0.908, "max_samples sigmoid": 0.824,
"impurity liblinear": 0.9606666666666667, "impurity liblinear": 0.9493333333333334,
"impurity linear": 0.9486666666666667, "impurity linear": 0.9426666666666667,
"impurity rbf": 0.978, "impurity rbf": 0.9606666666666667,
"impurity poly": 0.96, "impurity poly": 0.9373333333333334,
"impurity sigmoid": 0.908, "impurity sigmoid": 0.824,
}, },
"Iris": { "Iris": {
"max_samples liblinear": 1.0, "max_samples liblinear": 0.9550561797752809,
"max_samples linear": 1.0, "max_samples linear": 1.0,
"max_samples rbf": 0.7808988764044944, "max_samples rbf": 0.6685393258426966,
"max_samples poly": 0.8202247191011236, "max_samples poly": 0.6853932584269663,
"max_samples sigmoid": 0.7528089887640449, "max_samples sigmoid": 0.6404494382022472,
"impurity liblinear": 1.0, "impurity liblinear": 0.9550561797752809,
"impurity linear": 1.0, "impurity linear": 1.0,
"impurity rbf": 0.7808988764044944, "impurity rbf": 0.6685393258426966,
"impurity poly": 0.8202247191011236, "impurity poly": 0.6853932584269663,
"impurity sigmoid": 0.7528089887640449, "impurity sigmoid": 0.6404494382022472,
}, },
} }
@@ -272,8 +272,7 @@ class Stree_test(unittest.TestCase):
for criteria in ["max_samples", "impurity"]: for criteria in ["max_samples", "impurity"]:
for kernel in self._kernels: for kernel in self._kernels:
clf = Stree( clf = Stree(
C=55, max_iter=1e4,
max_iter=1e5,
multiclass_strategy="ovr" multiclass_strategy="ovr"
if kernel == "liblinear" if kernel == "liblinear"
else "ovo", else "ovo",
@@ -286,6 +285,7 @@ class Stree_test(unittest.TestCase):
self.assertAlmostEqual( self.assertAlmostEqual(
outcome, outcome,
clf.score(px, py), clf.score(px, py),
5,
f"{name} - {criteria} - {kernel}", f"{name} - {criteria} - {kernel}",
) )