Begin print_strees_test

This commit is contained in:
2022-05-09 01:00:51 +02:00
parent 534f32b625
commit b0c94d4983
3 changed files with 12 additions and 8 deletions

View File

@@ -1,6 +1,5 @@
#!/usr/bin/env python #!/usr/bin/env python
import os import os
import subprocess
import json import json
from stree import Stree from stree import Stree
from graphviz import Source from graphviz import Source
@@ -74,18 +73,16 @@ def print_stree(clf, dataset, X, y, color, quiet):
file_name = os.path.join(Folders.img, f"stree_{dataset}") file_name = os.path.join(Folders.img, f"stree_{dataset}")
grp.render(format="png", filename=f"{file_name}") grp.render(format="png", filename=f"{file_name}")
os.remove(f"{file_name}") os.remove(f"{file_name}")
print(f"File {file_name}.png generated") file_name += ".png"
if not quiet: print(f"File {file_name} generated")
cmd_open = "/usr/bin/open" Files.open(name=file_name, test=quiet)
if os.path.isfile(cmd_open) and os.access(cmd_open, os.X_OK):
subprocess.run([cmd_open, f"{file_name}.png"])
def main(args_test=None): def main(args_test=None):
arguments = Arguments() arguments = Arguments()
arguments.xset("color").xset("dataset", default="all").xset("quiet") arguments.xset("color").xset("dataset", default="all").xset("quiet")
args = arguments.parse(args_test) args = arguments.parse(args_test)
hyperparameters = load_hyperparams("accuracy", "ODTE") hyperparameters = load_hyperparams("accuracy", "STree")
random_state = 57 random_state = 57
dt = Datasets() dt = Datasets()
for dataset in dt: for dataset in dt:

2
benchmark/tests/img/.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
*
!.gitignore

View File

@@ -1,14 +1,19 @@
import shutil
from io import StringIO from io import StringIO
from unittest.mock import patch from unittest.mock import patch
from ...Results import Report from ...Results import Report
from ...Utils import Files
from ..TestBase import TestBase from ..TestBase import TestBase
class BePrintStrees(TestBase): class BePrintStrees(TestBase):
def setUp(self): def setUp(self):
self.prepare_scripts_env() self.prepare_scripts_env()
source = Files.grid_output("accuracy", "STree")
target = Files.grid_output("accuracy", "STree")
shutil.copy(source, target)
self.score = "accuracy" self.score = "accuracy"
self.files = [] self.files = [target]
def tearDown(self) -> None: def tearDown(self) -> None:
self.remove_files(self.files, ".") self.remove_files(self.files, ".")