diff --git a/src/print_strees.py b/src/print_strees old mode 100644 new mode 100755 similarity index 57% rename from src/print_strees.py rename to src/print_strees index fb21962..98dd335 --- a/src/print_strees.py +++ b/src/print_strees @@ -1,3 +1,4 @@ +#!/usr/bin/env python import os import subprocess import argparse @@ -18,8 +19,16 @@ def parse_arguments(): default=False, help="use colors for the tree", ) + ap.add_argument( + "-d", + "--dataset", + type=str, + required=False, + default="all", + help="dataset to print or all", + ) args = ap.parse_args() - return (args.color,) + return (args.color, args.dataset) def compute_stree(X, y, random_state): @@ -54,14 +63,37 @@ def build_title(dataset, accuracy, n_samples, n_features, n_classes, nodes): ) -def print_stree(clf, dataset, X, y): +def add_color(source): + return ( + source.replace( # Background and title font color + "fontcolor=blue", "fontcolor=white\nbgcolor=darkslateblue" + ) + .replace("brown", "cyan") # subtitle font color + .replace( # Fill leaves + "style=filled", 'style="filled" fillcolor="/blues5/1:/blues5/4"' + ) + .replace( # Fill nodes + "fontcolor=black", + 'style=radial fillcolor="orange:white" gradientangle=60', + ) + .replace("color=black", "color=white") # arrow color + .replace( # accuracy / # nodes + 'color="red"', 'color="darkolivegreen1"' + ) + ) + + +def print_stree(clf, dataset, X, y, color): output_folder = "img" samples, features = X.shape classes = max(y) + 1 accuracy = clf.score(X, y) nodes, _ = clf.nodes_leaves() title = build_title(dataset, accuracy, samples, features, classes, nodes) - grp = Source(clf.graph(title)) + dot_source = clf.graph(title) + if color: + dot_source = add_color(dot_source) + grp = Source(dot_source) file_name = os.path.join(output_folder, f"stree_{dataset}") grp.render(format="png", filename=f"{file_name}") os.remove(f"{file_name}") @@ -72,14 +104,17 @@ def print_stree(clf, dataset, X, y): if __name__ == "__main__": - (color,) = parse_arguments() + (color, dataset_chosen) = parse_arguments() hyperparameters = load_hyperparams("accuracy", "ODTE") random_state = 57 dt = Datasets() for dataset in dt: - X, y = dt.load(dataset) - clf = Stree(random_state=random_state) - hyperparams_dataset = hyperparam_filter(hyperparameters[dataset][1]) - clf.set_params(**hyperparams_dataset) - clf.fit(X, y) - print_stree(clf, dataset, X, y) + if dataset == dataset_chosen or dataset_chosen == "all": + X, y = dt.load(dataset) + clf = Stree(random_state=random_state) + hyperparams_dataset = hyperparam_filter( + hyperparameters[dataset][1] + ) + clf.set_params(**hyperparams_dataset) + clf.fit(X, y) + print_stree(clf, dataset, X, y, color)