Add color option to print_strees

This commit is contained in:
2022-04-05 12:32:31 +02:00
parent b97bcc8f93
commit 73f7f9d7ae

55
src/print_strees.py → src/print_strees Normal file → Executable file
View File

@@ -1,3 +1,4 @@
#!/usr/bin/env python
import os
import subprocess
import argparse
@@ -18,8 +19,16 @@ def parse_arguments():
default=False,
help="use colors for the tree",
)
ap.add_argument(
"-d",
"--dataset",
type=str,
required=False,
default="all",
help="dataset to print or all",
)
args = ap.parse_args()
return (args.color,)
return (args.color, args.dataset)
def compute_stree(X, y, random_state):
@@ -54,14 +63,37 @@ def build_title(dataset, accuracy, n_samples, n_features, n_classes, nodes):
)
def print_stree(clf, dataset, X, y):
def add_color(source):
    """Return *source* with its color attributes swapped for a dark theme.

    *source* is the text that ``print_stree`` obtains from ``clf.graph(title)``
    and hands to ``Source`` for rendering.  Every substitution is a plain
    substring replacement, applied to all occurrences and in the order
    listed below.  Text that matches none of the patterns is returned
    unchanged.

    The order matters: ``fontcolor=black`` must be rewritten before the
    generic ``color=black`` rule runs, otherwise the latter would match
    inside it and turn it into ``fontcolor=white``.
    """
    replacements = (
        # Background and title font color
        ("fontcolor=blue", "fontcolor=white\nbgcolor=darkslateblue"),
        # Subtitle font color
        ("brown", "cyan"),
        # Fill leaves
        ("style=filled", 'style="filled" fillcolor="/blues5/1:/blues5/4"'),
        # Fill nodes
        (
            "fontcolor=black",
            'style=radial fillcolor="orange:white" gradientangle=60',
        ),
        # Arrow color
        ("color=black", "color=white"),
        # Accuracy / number of nodes
        ('color="red"', 'color="darkolivegreen1"'),
    )
    for old, new in replacements:
        source = source.replace(old, new)
    return source
def print_stree(clf, dataset, X, y, color):
output_folder = "img"
samples, features = X.shape
classes = max(y) + 1
accuracy = clf.score(X, y)
nodes, _ = clf.nodes_leaves()
title = build_title(dataset, accuracy, samples, features, classes, nodes)
grp = Source(clf.graph(title))
dot_source = clf.graph(title)
if color:
dot_source = add_color(dot_source)
grp = Source(dot_source)
file_name = os.path.join(output_folder, f"stree_{dataset}")
grp.render(format="png", filename=f"{file_name}")
os.remove(f"{file_name}")
@@ -72,14 +104,17 @@ def print_stree(clf, dataset, X, y):
if __name__ == "__main__":
    # parse_arguments() returns the color flag and the dataset selector;
    # the selector defaults to "all".
    (color, dataset_chosen) = parse_arguments()
    hyperparameters = load_hyperparams("accuracy", "ODTE")
    random_state = 57
    dt = Datasets()
    for dataset in dt:
        # Only the requested dataset is processed, unless "all" was chosen.
        if dataset_chosen != "all" and dataset != dataset_chosen:
            continue
        X, y = dt.load(dataset)
        clf = Stree(random_state=random_state)
        hyperparams_dataset = hyperparam_filter(hyperparameters[dataset][1])
        clf.set_params(**hyperparams_dataset)
        clf.fit(X, y)
        # Forward the color flag so print_stree can restyle the graph.
        print_stree(clf, dataset, X, y, color)