Add incompatible hyparams to be_main

This commit is contained in:
2022-11-20 19:10:28 +01:00
parent 1b8a424ad3
commit 3ade3f4022
4 changed files with 64 additions and 22 deletions

View File

@@ -55,13 +55,13 @@ class Arguments(argparse.ArgumentParser):
self._overrides = {} self._overrides = {}
self._subparser = None self._subparser = None
self.parameters = { self.parameters = {
"best": [ "best_paramfile": [
("-b", "--best"), ("-b", "--best_paramfile"),
{ {
"required": False,
"action": "store_true", "action": "store_true",
"required": False,
"default": False, "default": False,
"help": "best results of models", "help": "Use best hyperparams file?",
}, },
], ],
"color": [ "color": [
@@ -107,7 +107,7 @@ class Arguments(argparse.ArgumentParser):
"required": False, "required": False,
"action": "store_true", "action": "store_true",
"default": False, "default": False,
"help": "Use best hyperparams file?", "help": "Use grid output hyperparams file?",
}, },
], ],
"hidden": [ "hidden": [
@@ -198,15 +198,6 @@ class Arguments(argparse.ArgumentParser):
"help": "number of folds", "help": "number of folds",
}, },
], ],
"paramfile": [
("-f", "--paramfile"),
{
"action": "store_true",
"required": False,
"default": False,
"help": "Use best hyperparams file?",
},
],
"platform": [ "platform": [
("-P", "--platform"), ("-P", "--platform"),
{ {
@@ -314,6 +305,12 @@ class Arguments(argparse.ArgumentParser):
# Order of args is important # Order of args is important
parser.add_argument(*names, **{**args, **parameters}) parser.add_argument(*names, **{**args, **parameters})
def add_exclusive(self, hyperparameters, required=False):
group = self.add_mutually_exclusive_group(required=required)
for name in hyperparameters:
names, parameters = self.parameters[name]
group.add_argument(*names, **parameters)
def parse(self, args=None): def parse(self, args=None):
for key, (dest_key, value) in self._overrides.items(): for key, (dest_key, value) in self._overrides.items():
if args is None: if args is None:

View File

@@ -10,18 +10,20 @@ from benchmark.Arguments import Arguments
def main(args_test=None): def main(args_test=None):
arguments = Arguments() arguments = Arguments(prog="be_main")
arguments.xset("stratified").xset("score").xset("model", mandatory=True) arguments.xset("stratified").xset("score").xset("model", mandatory=True)
arguments.xset("n_folds").xset("platform").xset("quiet").xset("title") arguments.xset("n_folds").xset("platform").xset("quiet").xset("title")
arguments.xset("hyperparameters").xset("paramfile").xset("report") arguments.xset("report")
arguments.xset("grid_paramfile") arguments.add_exclusive(
["grid_paramfile", "best_paramfile", "hyperparameters"]
)
arguments.xset( arguments.xset(
"dataset", overrides="title", const="Test with only one dataset" "dataset", overrides="title", const="Test with only one dataset"
) )
args = arguments.parse(args_test) args = arguments.parse(args_test)
report = args.report or args.dataset is not None report = args.report or args.dataset is not None
if args.grid_paramfile: if args.grid_paramfile:
args.paramfile = False args.best_paramfile = False
try: try:
job = Experiment( job = Experiment(
score_name=args.score, score_name=args.score,
@@ -29,7 +31,7 @@ def main(args_test=None):
stratified=args.stratified, stratified=args.stratified,
datasets=Datasets(dataset_name=args.dataset), datasets=Datasets(dataset_name=args.dataset),
hyperparams_dict=args.hyperparameters, hyperparams_dict=args.hyperparameters,
hyperparams_file=args.paramfile, hyperparams_file=args.best_paramfile,
grid_paramfile=args.grid_paramfile, grid_paramfile=args.grid_paramfile,
progress_bar=not args.quiet, progress_bar=not args.quiet,
platform=args.platform, platform=args.platform,

View File

@@ -24,6 +24,7 @@ class ArgumentsTest(TestBase):
def test_parameters(self): def test_parameters(self):
expected_parameters = { expected_parameters = {
"best_paramfile": ("-b", "--best_paramfile"),
"color": ("-c", "--color"), "color": ("-c", "--color"),
"compare": ("-c", "--compare"), "compare": ("-c", "--compare"),
"dataset": ("-d", "--dataset"), "dataset": ("-d", "--dataset"),
@@ -39,7 +40,6 @@ class ArgumentsTest(TestBase):
"nan": ("--nan",), "nan": ("--nan",),
"number": ("-n", "--number"), "number": ("-n", "--number"),
"n_folds": ("-n", "--n_folds"), "n_folds": ("-n", "--n_folds"),
"paramfile": ("-f", "--paramfile"),
"platform": ("-P", "--platform"), "platform": ("-P", "--platform"),
"quiet": ("-q", "--quiet"), "quiet": ("-q", "--quiet"),
"report": ("-r", "--report"), "report": ("-r", "--report"),

View File

@@ -1,4 +1,5 @@
import os import os
import json
from io import StringIO from io import StringIO
from unittest.mock import patch from unittest.mock import patch
from ...Results import Report from ...Results import Report
@@ -66,7 +67,7 @@ class BeMainTest(TestBase):
"STree", "STree",
"--title", "--title",
"test", "test",
"-f", "-b",
"-r", "-r",
], ],
) )
@@ -77,6 +78,48 @@ class BeMainTest(TestBase):
stdout, "be_main_best", [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14] stdout, "be_main_best", [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14]
) )
@patch("sys.stdout", new_callable=StringIO)
@patch("sys.stderr", new_callable=StringIO)
def test_be_main_incompatible_params(self, stdout, stderr):
m1 = (
"be_main: error: argument -b/--best_paramfile: not allowed with "
"argument -p/--hyperparameters"
)
m2 = (
"be_main: error: argument -g/--grid_paramfile: not allowed with "
"argument -p/--hyperparameters"
)
m3 = (
"be_main: error: argument -g/--grid_paramfile: not allowed with "
"argument -p/--hyperparameters"
)
m4 = m1
p0 = [
"-s",
self.score,
"-m",
"SVC",
"--title",
"test",
]
pset = json.dumps(dict(C=17))
p1 = p0.copy()
p1.extend(["-p", pset, "-b"])
p2 = p0.copy()
p2.extend(["-p", pset, "-g"])
p3 = p0.copy()
p3.extend(["-p", pset, "-g", "-b"])
p4 = p0.copy()
p4.extend(["-b", "-g"])
parameters = [(p1, m1), (p2, m2), (p3, m3), (p4, m4)]
for parameter, message in parameters:
with self.assertRaises(SystemExit) as msg:
module = self.search_script("be_main")
module.main(parameter)
self.assertEqual(msg.exception.code, 2)
self.assertEqual(stderr.getvalue(), "")
self.assertRegexpMatches(stdout.getvalue(), message)
def test_be_main_best_params_non_existent(self): def test_be_main_best_params_non_existent(self):
model = "GBC" model = "GBC"
stdout, stderr = self.execute_script( stdout, stderr = self.execute_script(
@@ -88,7 +131,7 @@ class BeMainTest(TestBase):
model, model,
"--title", "--title",
"test", "test",
"-f", "-b",
"-r", "-r",
], ],
) )