From 3ade3f402239cb47c3cbf7fa16196419625c1b91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana?= Date: Sun, 20 Nov 2022 19:10:28 +0100 Subject: [PATCH] Add incompatible hyparams to be_main --- benchmark/Arguments.py | 25 ++++++------- benchmark/scripts/be_main.py | 12 ++++--- benchmark/tests/Arguments_test.py | 2 +- benchmark/tests/scripts/Be_Main_test.py | 47 +++++++++++++++++++++++-- 4 files changed, 64 insertions(+), 22 deletions(-) diff --git a/benchmark/Arguments.py b/benchmark/Arguments.py index 1c8be8e..33fc88a 100644 --- a/benchmark/Arguments.py +++ b/benchmark/Arguments.py @@ -55,13 +55,13 @@ class Arguments(argparse.ArgumentParser): self._overrides = {} self._subparser = None self.parameters = { - "best": [ - ("-b", "--best"), + "best_paramfile": [ + ("-b", "--best_paramfile"), { - "required": False, "action": "store_true", + "required": False, "default": False, - "help": "best results of models", + "help": "Use best hyperparams file?", }, ], "color": [ @@ -107,7 +107,7 @@ class Arguments(argparse.ArgumentParser): "required": False, "action": "store_true", "default": False, - "help": "Use best hyperparams file?", + "help": "Use grid output hyperparams file?", }, ], "hidden": [ @@ -198,15 +198,6 @@ class Arguments(argparse.ArgumentParser): "help": "number of folds", }, ], - "paramfile": [ - ("-f", "--paramfile"), - { - "action": "store_true", - "required": False, - "default": False, - "help": "Use best hyperparams file?", - }, - ], "platform": [ ("-P", "--platform"), { @@ -314,6 +305,12 @@ class Arguments(argparse.ArgumentParser): # Order of args is important parser.add_argument(*names, **{**args, **parameters}) + def add_exclusive(self, hyperparameters, required=False): + group = self.add_mutually_exclusive_group(required=required) + for name in hyperparameters: + names, parameters = self.parameters[name] + group.add_argument(*names, **parameters) + def parse(self, args=None): for key, (dest_key, value) in self._overrides.items(): if args is None: diff --git a/benchmark/scripts/be_main.py b/benchmark/scripts/be_main.py index 33a3428..2786967 100755 --- a/benchmark/scripts/be_main.py +++ b/benchmark/scripts/be_main.py @@ -10,18 +10,20 @@ from benchmark.Arguments import Arguments def main(args_test=None): - arguments = Arguments() + arguments = Arguments(prog="be_main") arguments.xset("stratified").xset("score").xset("model", mandatory=True) arguments.xset("n_folds").xset("platform").xset("quiet").xset("title") - arguments.xset("hyperparameters").xset("paramfile").xset("report") - arguments.xset("grid_paramfile") + arguments.xset("report") + arguments.add_exclusive( + ["grid_paramfile", "best_paramfile", "hyperparameters"] + ) arguments.xset( "dataset", overrides="title", const="Test with only one dataset" ) args = arguments.parse(args_test) report = args.report or args.dataset is not None if args.grid_paramfile: - args.paramfile = False + args.best_paramfile = False try: job = Experiment( score_name=args.score, @@ -29,7 +31,7 @@ def main(args_test=None): stratified=args.stratified, datasets=Datasets(dataset_name=args.dataset), hyperparams_dict=args.hyperparameters, - hyperparams_file=args.paramfile, + hyperparams_file=args.best_paramfile, grid_paramfile=args.grid_paramfile, progress_bar=not args.quiet, platform=args.platform, diff --git a/benchmark/tests/Arguments_test.py b/benchmark/tests/Arguments_test.py index 2431479..737f1df 100644 --- a/benchmark/tests/Arguments_test.py +++ b/benchmark/tests/Arguments_test.py @@ -24,6 +24,7 @@ class ArgumentsTest(TestBase): def test_parameters(self): expected_parameters = { + "best_paramfile": ("-b", "--best_paramfile"), "color": ("-c", "--color"), "compare": ("-c", "--compare"), "dataset": ("-d", "--dataset"), @@ -39,7 +40,6 @@ class ArgumentsTest(TestBase): "nan": ("--nan",), "number": ("-n", "--number"), "n_folds": ("-n", "--n_folds"), - "paramfile": ("-f", "--paramfile"), "platform": ("-P", "--platform"), "quiet": ("-q", "--quiet"), "report": ("-r", "--report"), diff --git a/benchmark/tests/scripts/Be_Main_test.py b/benchmark/tests/scripts/Be_Main_test.py index 6464174..fabb305 100644 --- a/benchmark/tests/scripts/Be_Main_test.py +++ b/benchmark/tests/scripts/Be_Main_test.py @@ -1,4 +1,5 @@ import os +import json from io import StringIO from unittest.mock import patch from ...Results import Report @@ -66,7 +67,7 @@ class BeMainTest(TestBase): "STree", "--title", "test", - "-f", + "-b", "-r", ], ) @@ -77,6 +78,48 @@ class BeMainTest(TestBase): stdout, "be_main_best", [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14] ) + @patch("sys.stdout", new_callable=StringIO) + @patch("sys.stderr", new_callable=StringIO) + def test_be_main_incompatible_params(self, stdout, stderr): + m1 = ( + "be_main: error: argument -b/--best_paramfile: not allowed with " + "argument -p/--hyperparameters" + ) + m2 = ( + "be_main: error: argument -g/--grid_paramfile: not allowed with " + "argument -p/--hyperparameters" + ) + m3 = ( + "be_main: error: argument -g/--grid_paramfile: not allowed with " + "argument -p/--hyperparameters" + ) + m4 = m1 + p0 = [ + "-s", + self.score, + "-m", + "SVC", + "--title", + "test", + ] + pset = json.dumps(dict(C=17)) + p1 = p0.copy() + p1.extend(["-p", pset, "-b"]) + p2 = p0.copy() + p2.extend(["-p", pset, "-g"]) + p3 = p0.copy() + p3.extend(["-p", pset, "-g", "-b"]) + p4 = p0.copy() + p4.extend(["-b", "-g"]) + parameters = [(p1, m1), (p2, m2), (p3, m3), (p4, m4)] + for parameter, message in parameters: + with self.assertRaises(SystemExit) as msg: + module = self.search_script("be_main") + module.main(parameter) + self.assertEqual(msg.exception.code, 2) + self.assertEqual(stderr.getvalue(), "") + self.assertRegexpMatches(stdout.getvalue(), message) + def test_be_main_best_params_non_existent(self): model = "GBC" stdout, stderr = self.execute_script( @@ -88,7 +131,7 @@ class BeMainTest(TestBase): model, "--title", "test", - "-f", + "-b", "-r", ], )