diff --git a/stree/Splitter.py b/stree/Splitter.py index e9ac169..4924721 100644 --- a/stree/Splitter.py +++ b/stree/Splitter.py @@ -154,19 +154,17 @@ class Snode: if self.is_leaf(): output += ( f'N{id(self)} [shape=box style=filled label="' - f"class={self._class} belief={self._belief: .3f} " - f"impurity={self._impurity:.3f} " - f'classes/samples={count_values}"];\n' + f"class={self._class} impurity={self._impurity:.3f} " + f'classes={count_values[0]} samples={count_values[1]}"];\n' ) else: output += ( f'N{id(self)} [label="#features={len(self._features)} ' - f'classes/samples={count_values}"];\n' - ) - output += f'N{id(self)} -> N{id(self.get_up())} [label="Up"];\n' - output += ( - f'N{id(self)} -> N{id(self.get_down())} [label="Down"];\n' + f"classes={count_values[0]} samples={count_values[1]} " + f'({sum(count_values[1])})"];\n' ) + output += f"N{id(self)} -> N{id(self.get_up())};\n" + output += f"N{id(self)} -> N{id(self.get_down())};\n" return output def __str__(self) -> str: diff --git a/stree/Strees.py b/stree/Strees.py index 97dc181..64329e3 100644 --- a/stree/Strees.py +++ b/stree/Strees.py @@ -476,7 +476,7 @@ class Stree(BaseEstimator, ClassifierMixin): tree = None return Siterator(tree) - def graph(self) -> str: + def graph(self, title="") -> str: """Graphviz code representing the tree Returns @@ -484,7 +484,10 @@ class Stree(BaseEstimator, ClassifierMixin): str graphviz code """ - output = "digraph STree {\n" + output = ( + "digraph STree {\nlabel=\nfontsize=30\nfontcolor=blue\nlabelloc=t\n" + ) for node in self: output += node.graph() output += "}\n" diff --git a/stree/tests/Stree_test.py b/stree/tests/Stree_test.py index 95ca945..63c210f 100644 --- a/stree/tests/Stree_test.py +++ b/stree/tests/Stree_test.py @@ -688,17 +688,40 @@ class Stree_test(unittest.TestCase): """Check graphviz representation of the tree.""" X, y = load_wine(return_X_y=True) clf = Stree(random_state=self._random_state) - self.assertEqual(clf.graph(), "digraph STree {\n}\n") - clf.fit(X, y) - expected_head = "digraph STree {\n" - expected_tail = ( - ' [shape=box style=filled label="class=1 belief= ' - '1.000 impurity=0.000 classes/samples=(array([1]), array([1]))"]' - ";\n}\n" + + expected_head = ( + "digraph STree {\nlabel=\nfontsize=30\n" + "fontcolor=blue\nlabelloc=t\n" ) + expected_tail = ( + ' [shape=box style=filled label="class=1 impurity=0.000 ' + 'classes=[1] samples=[1]"];\n}\n' + ) + self.assertEqual(clf.graph(), expected_head + "}\n") + clf.fit(X, y) computed = clf.graph() computed_head = computed[: len(expected_head)] num = -len(expected_tail) computed_tail = computed[num:] self.assertEqual(computed_head, expected_head) self.assertEqual(computed_tail, expected_tail) + + def test_graph_title(self): + X, y = load_wine(return_X_y=True) + clf = Stree(random_state=self._random_state) + expected_head = ( + "digraph STree {\nlabel=\nfontsize=30\n" + "fontcolor=blue\nlabelloc=t\n" + ) + expected_tail = ( + ' [shape=box style=filled label="class=1 impurity=0.000 ' + 'classes=[1] samples=[1]"];\n}\n' + ) + self.assertEqual(clf.graph("Sample title"), expected_head + "}\n") + clf.fit(X, y) + computed = clf.graph("Sample title") + computed_head = computed[: len(expected_head)] + num = -len(expected_tail) + computed_tail = computed[num:] + self.assertEqual(computed_head, expected_head) + self.assertEqual(computed_tail, expected_tail)