Begin print_strees_test

This commit is contained in:
2022-05-09 01:00:51 +02:00
parent 534f32b625
commit b0c94d4983
3 changed files with 12 additions and 8 deletions

View File

@@ -1,6 +1,5 @@
#!/usr/bin/env python
import os
import subprocess
import json
from stree import Stree
from graphviz import Source
@@ -74,18 +73,16 @@ def print_stree(clf, dataset, X, y, color, quiet):
file_name = os.path.join(Folders.img, f"stree_{dataset}")
grp.render(format="png", filename=f"{file_name}")
os.remove(f"{file_name}")
print(f"File {file_name}.png generated")
if not quiet:
cmd_open = "/usr/bin/open"
if os.path.isfile(cmd_open) and os.access(cmd_open, os.X_OK):
subprocess.run([cmd_open, f"{file_name}.png"])
file_name += ".png"
print(f"File {file_name} generated")
Files.open(name=file_name, test=quiet)
def main(args_test=None):
arguments = Arguments()
arguments.xset("color").xset("dataset", default="all").xset("quiet")
args = arguments.parse(args_test)
hyperparameters = load_hyperparams("accuracy", "ODTE")
hyperparameters = load_hyperparams("accuracy", "STree")
random_state = 57
dt = Datasets()
for dataset in dt:

2
benchmark/tests/img/.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
*
!.gitignore

View File

@@ -1,14 +1,19 @@
import shutil
from io import StringIO
from unittest.mock import patch
from ...Results import Report
from ...Utils import Files
from ..TestBase import TestBase
class BePrintStrees(TestBase):
def setUp(self):
self.prepare_scripts_env()
source = Files.grid_output("accuracy", "STree")
target = Files.grid_output("accuracy", "STree")
shutil.copy(source, target)
self.score = "accuracy"
self.files = []
self.files = [target]
def tearDown(self) -> None:
self.remove_files(self.files, ".")