mirror of
https://github.com/Doctorado-ML/benchmark.git
synced 2025-08-15 23:45:54 +00:00
Add color option to print_strees
This commit is contained in:
55
src/print_strees.py → src/print_strees
Normal file → Executable file
55
src/print_strees.py → src/print_strees
Normal file → Executable file
@@ -1,3 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
import os
|
||||
import subprocess
|
||||
import argparse
|
||||
@@ -18,8 +19,16 @@ def parse_arguments():
|
||||
default=False,
|
||||
help="use colors for the tree",
|
||||
)
|
||||
ap.add_argument(
|
||||
"-d",
|
||||
"--dataset",
|
||||
type=str,
|
||||
required=False,
|
||||
default="all",
|
||||
help="dataset to print or all",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
return (args.color,)
|
||||
return (args.color, args.dataset)
|
||||
|
||||
|
||||
def compute_stree(X, y, random_state):
|
||||
@@ -54,14 +63,37 @@ def build_title(dataset, accuracy, n_samples, n_features, n_classes, nodes):
|
||||
)
|
||||
|
||||
|
||||
def print_stree(clf, dataset, X, y):
|
||||
def add_color(source):
    """Return ``source`` recolored through a fixed list of replacements.

    ``source`` is treated as plain text: every occurrence of each listed
    substring is replaced, one rule after another, in the order given
    below.  The order matters because ``"color=black"`` is a substring of
    ``"fontcolor=black"``; the longer pattern has to be consumed first.

    Parameters
    ----------
    source : str
        Text of the graph description to recolor.

    Returns
    -------
    str
        A new string with all replacements applied; ``source`` itself is
        not modified.
    """
    replacements = (
        # Background and title font color
        ("fontcolor=blue", "fontcolor=white\nbgcolor=darkslateblue"),
        # subtitle font color
        ("brown", "cyan"),
        # Fill leaves
        ("style=filled", 'style="filled" fillcolor="/blues5/1:/blues5/4"'),
        # Fill nodes
        (
            "fontcolor=black",
            'style=radial fillcolor="orange:white" gradientangle=60',
        ),
        # arrow color
        ("color=black", "color=white"),
        # accuracy / # nodes
        ('color="red"', 'color="darkolivegreen1"'),
    )
    for old, new in replacements:
        source = source.replace(old, new)
    return source
|
||||
|
||||
|
||||
def print_stree(clf, dataset, X, y, color):
|
||||
output_folder = "img"
|
||||
samples, features = X.shape
|
||||
classes = max(y) + 1
|
||||
accuracy = clf.score(X, y)
|
||||
nodes, _ = clf.nodes_leaves()
|
||||
title = build_title(dataset, accuracy, samples, features, classes, nodes)
|
||||
grp = Source(clf.graph(title))
|
||||
dot_source = clf.graph(title)
|
||||
if color:
|
||||
dot_source = add_color(dot_source)
|
||||
grp = Source(dot_source)
|
||||
file_name = os.path.join(output_folder, f"stree_{dataset}")
|
||||
grp.render(format="png", filename=f"{file_name}")
|
||||
os.remove(f"{file_name}")
|
||||
@@ -72,14 +104,17 @@ def print_stree(clf, dataset, X, y):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
(color,) = parse_arguments()
|
||||
(color, dataset_chosen) = parse_arguments()
|
||||
hyperparameters = load_hyperparams("accuracy", "ODTE")
|
||||
random_state = 57
|
||||
dt = Datasets()
|
||||
for dataset in dt:
|
||||
X, y = dt.load(dataset)
|
||||
clf = Stree(random_state=random_state)
|
||||
hyperparams_dataset = hyperparam_filter(hyperparameters[dataset][1])
|
||||
clf.set_params(**hyperparams_dataset)
|
||||
clf.fit(X, y)
|
||||
print_stree(clf, dataset, X, y)
|
||||
if dataset == dataset_chosen or dataset_chosen == "all":
|
||||
X, y = dt.load(dataset)
|
||||
clf = Stree(random_state=random_state)
|
||||
hyperparams_dataset = hyperparam_filter(
|
||||
hyperparameters[dataset][1]
|
||||
)
|
||||
clf.set_params(**hyperparams_dataset)
|
||||
clf.fit(X, y)
|
||||
print_stree(clf, dataset, X, y, color)
|
Reference in New Issue
Block a user