mirror of
https://github.com/Doctorado-ML/benchmark.git
synced 2025-08-16 07:55:54 +00:00
Add quiet argument to generate only the trees
This commit is contained in:
@@ -27,8 +27,16 @@ def parse_arguments():
|
|||||||
default="all",
|
default="all",
|
||||||
help="dataset to print or all",
|
help="dataset to print or all",
|
||||||
)
|
)
|
||||||
|
ap.add_argument(
|
||||||
|
"-q",
|
||||||
|
"--quiet",
|
||||||
|
type=bool,
|
||||||
|
required=False,
|
||||||
|
default=False,
|
||||||
|
help="don't print generated tree(s)",
|
||||||
|
)
|
||||||
args = ap.parse_args()
|
args = ap.parse_args()
|
||||||
return (args.color, args.dataset)
|
return (args.color, args.dataset, args.quiet)
|
||||||
|
|
||||||
|
|
||||||
def compute_stree(X, y, random_state):
|
def compute_stree(X, y, random_state):
|
||||||
@@ -83,7 +91,7 @@ def add_color(source):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def print_stree(clf, dataset, X, y, color):
|
def print_stree(clf, dataset, X, y, color, quiet):
|
||||||
output_folder = "img"
|
output_folder = "img"
|
||||||
samples, features = X.shape
|
samples, features = X.shape
|
||||||
classes = max(y) + 1
|
classes = max(y) + 1
|
||||||
@@ -98,13 +106,14 @@ def print_stree(clf, dataset, X, y, color):
|
|||||||
grp.render(format="png", filename=f"{file_name}")
|
grp.render(format="png", filename=f"{file_name}")
|
||||||
os.remove(f"{file_name}")
|
os.remove(f"{file_name}")
|
||||||
print(f"File {file_name}.png generated")
|
print(f"File {file_name}.png generated")
|
||||||
cmd_open = "/usr/bin/open"
|
if not quiet:
|
||||||
if os.path.isfile(cmd_open) and os.access(cmd_open, os.X_OK):
|
cmd_open = "/usr/bin/open"
|
||||||
subprocess.run([cmd_open, f"{file_name}.png"])
|
if os.path.isfile(cmd_open) and os.access(cmd_open, os.X_OK):
|
||||||
|
subprocess.run([cmd_open, f"{file_name}.png"])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
(color, dataset_chosen) = parse_arguments()
|
(color, dataset_chosen, quiet) = parse_arguments()
|
||||||
hyperparameters = load_hyperparams("accuracy", "ODTE")
|
hyperparameters = load_hyperparams("accuracy", "ODTE")
|
||||||
random_state = 57
|
random_state = 57
|
||||||
dt = Datasets()
|
dt = Datasets()
|
||||||
@@ -117,4 +126,4 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
clf.set_params(**hyperparams_dataset)
|
clf.set_params(**hyperparams_dataset)
|
||||||
clf.fit(X, y)
|
clf.fit(X, y)
|
||||||
print_stree(clf, dataset, X, y, color)
|
print_stree(clf, dataset, X, y, color, quiet)
|
||||||
|
Reference in New Issue
Block a user