From ca96d0512443e465167f00d36103b8ea5145dcaf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Mon, 9 May 2022 01:34:25 +0200 Subject: [PATCH] Complete be_print_strees --- benchmark/scripts/be_print_strees.py | 24 +++---- .../tests/scripts/Be_Print_Strees_test.py | 64 +++++++++++++------ 2 files changed, 53 insertions(+), 35 deletions(-) diff --git a/benchmark/scripts/be_print_strees.py b/benchmark/scripts/be_print_strees.py index a6a369e..5122cbe 100755 --- a/benchmark/scripts/be_print_strees.py +++ b/benchmark/scripts/be_print_strees.py @@ -8,12 +8,6 @@ from benchmark.Utils import Files, Folders from benchmark.Arguments import Arguments -def compute_stree(X, y, random_state): - clf = Stree(random_state=random_state) - clf.fit(X, y) - return clf - - def load_hyperparams(score_name, model_name): grid_file = os.path.join( Folders.results, Files.grid_output(score_name, model_name) @@ -22,13 +16,13 @@ def load_hyperparams(score_name, model_name): return json.load(f) -def hyperparam_filter(hyperparams): - res = {} - for key, value in hyperparams.items(): - if key.startswith("base_estimator"): - newkey = key.split("__")[1] - res[newkey] = value - return res +# def hyperparam_filter(hyperparams): +# res = {} +# for key, value in hyperparams.items(): +# if key.startswith("base_estimator"): +# newkey = key.split("__")[1] +# res[newkey] = value +# return res def build_title(dataset, accuracy, n_samples, n_features, n_classes, nodes): @@ -89,9 +83,7 @@ def main(args_test=None): if dataset == args.dataset or args.dataset == "all": X, y = dt.load(dataset) clf = Stree(random_state=random_state) - hyperparams_dataset = hyperparam_filter( - hyperparameters[dataset][1] - ) + hyperparams_dataset = hyperparameters[dataset][1] clf.set_params(**hyperparams_dataset) clf.fit(X, y) print_stree(clf, dataset, X, y, args.color, args.quiet) diff --git a/benchmark/tests/scripts/Be_Print_Strees_test.py b/benchmark/tests/scripts/Be_Print_Strees_test.py index f85b625..bc9ead0 100644 --- a/benchmark/tests/scripts/Be_Print_Strees_test.py +++ b/benchmark/tests/scripts/Be_Print_Strees_test.py @@ -1,34 +1,60 @@ -import shutil -from io import StringIO -from unittest.mock import patch -from ...Results import Report -from ...Utils import Files +import os +import hashlib +from ...Utils import Folders from ..TestBase import TestBase class BePrintStrees(TestBase): def setUp(self): self.prepare_scripts_env() - source = Files.grid_output("accuracy", "STree") - target = Files.grid_output("accuracy", "STree") - shutil.copy(source, target) self.score = "accuracy" - self.files = [target] + self.files = [] + self.datasets = ["balloons", "balance-scale"] + self.expected = { + "balloons": { + "color": "b2342cc27a4ab495970616346bedf73b", + "gray": "a9bc4d2041f2869a93164a548f6ad986", + }, + "balance-scale": { + "color": "2e85d66de1ae838d01a3f327397a50c8", + "gray": "30f325134d4b5153c9e6ecbcae7b6d1f", + }, + } def tearDown(self) -> None: self.remove_files(self.files, ".") return super().tearDown() - # def test_be_benchmark_dataset(self): - # stdout, _ = self.execute_script( - # "be_main", - # ["-m", "STree", "-d", "balloons", "--title", "test"], - # ) - # self.check_output_lines( - # stdout=stdout, - # file_name="be_main_dataset", - # lines_to_compare=[0, 2, 3, 5, 6, 7, 8, 9, 11, 12, 13], - # ) + def hash_file(self, name): + file_name = os.path.join(Folders.img, f"{name}.png") + self.files.append(file_name) + self.assertTrue(os.path.exists(file_name)) + with open(file_name, "rb") as f: + return hashlib.md5(f.read()).hexdigest(), file_name + + def test_be_print_strees_dataset_bn(self): + for name in self.datasets: + stdout, _ = self.execute_script( + "be_print_strees", + ["-d", name, "-q", "1"], + ) + computed_hash, file_name = self.hash_file(f"stree_{name}") + self.assertEqual( + stdout.getvalue(), f"File {file_name} generated\n" + ) + self.assertEqual(computed_hash, self.expected[name]["gray"]) + + def test_be_print_strees_dataset_color(self): + for name in self.datasets: + stdout, _ = self.execute_script( + "be_print_strees", + ["-d", name, "-q", "1", "-c", "1"], + ) + computed_hash, file_name = self.hash_file(f"stree_{name}") + self.assertEqual( + stdout.getvalue(), f"File {file_name} generated\n" + ) + self.assertEqual(computed_hash, self.expected[name]["color"]) # def test_be_benchmark_complete(self): # stdout, _ = self.execute_script(