Add incompatible hyparams to be_main

This commit is contained in:
2022-11-20 19:10:28 +01:00
parent 1b8a424ad3
commit 3ade3f4022
4 changed files with 64 additions and 22 deletions

View File

@@ -55,13 +55,13 @@ class Arguments(argparse.ArgumentParser):
self._overrides = {}
self._subparser = None
self.parameters = {
"best": [
("-b", "--best"),
"best_paramfile": [
("-b", "--best_paramfile"),
{
"required": False,
"action": "store_true",
"required": False,
"default": False,
"help": "best results of models",
"help": "Use best hyperparams file?",
},
],
"color": [
@@ -107,7 +107,7 @@ class Arguments(argparse.ArgumentParser):
"required": False,
"action": "store_true",
"default": False,
"help": "Use best hyperparams file?",
"help": "Use grid output hyperparams file?",
},
],
"hidden": [
@@ -198,15 +198,6 @@ class Arguments(argparse.ArgumentParser):
"help": "number of folds",
},
],
"paramfile": [
("-f", "--paramfile"),
{
"action": "store_true",
"required": False,
"default": False,
"help": "Use best hyperparams file?",
},
],
"platform": [
("-P", "--platform"),
{
@@ -314,6 +305,12 @@ class Arguments(argparse.ArgumentParser):
# Order of args is important
parser.add_argument(*names, **{**args, **parameters})
def add_exclusive(self, hyperparameters, required=False):
group = self.add_mutually_exclusive_group(required=required)
for name in hyperparameters:
names, parameters = self.parameters[name]
group.add_argument(*names, **parameters)
def parse(self, args=None):
for key, (dest_key, value) in self._overrides.items():
if args is None:

View File

@@ -10,18 +10,20 @@ from benchmark.Arguments import Arguments
def main(args_test=None):
arguments = Arguments()
arguments = Arguments(prog="be_main")
arguments.xset("stratified").xset("score").xset("model", mandatory=True)
arguments.xset("n_folds").xset("platform").xset("quiet").xset("title")
arguments.xset("hyperparameters").xset("paramfile").xset("report")
arguments.xset("grid_paramfile")
arguments.xset("report")
arguments.add_exclusive(
["grid_paramfile", "best_paramfile", "hyperparameters"]
)
arguments.xset(
"dataset", overrides="title", const="Test with only one dataset"
)
args = arguments.parse(args_test)
report = args.report or args.dataset is not None
if args.grid_paramfile:
args.paramfile = False
args.best_paramfile = False
try:
job = Experiment(
score_name=args.score,
@@ -29,7 +31,7 @@ def main(args_test=None):
stratified=args.stratified,
datasets=Datasets(dataset_name=args.dataset),
hyperparams_dict=args.hyperparameters,
hyperparams_file=args.paramfile,
hyperparams_file=args.best_paramfile,
grid_paramfile=args.grid_paramfile,
progress_bar=not args.quiet,
platform=args.platform,

View File

@@ -24,6 +24,7 @@ class ArgumentsTest(TestBase):
def test_parameters(self):
expected_parameters = {
"best_paramfile": ("-b", "--best_paramfile"),
"color": ("-c", "--color"),
"compare": ("-c", "--compare"),
"dataset": ("-d", "--dataset"),
@@ -39,7 +40,6 @@ class ArgumentsTest(TestBase):
"nan": ("--nan",),
"number": ("-n", "--number"),
"n_folds": ("-n", "--n_folds"),
"paramfile": ("-f", "--paramfile"),
"platform": ("-P", "--platform"),
"quiet": ("-q", "--quiet"),
"report": ("-r", "--report"),

View File

@@ -1,4 +1,5 @@
import os
import json
from io import StringIO
from unittest.mock import patch
from ...Results import Report
@@ -66,7 +67,7 @@ class BeMainTest(TestBase):
"STree",
"--title",
"test",
"-f",
"-b",
"-r",
],
)
@@ -77,6 +78,48 @@ class BeMainTest(TestBase):
stdout, "be_main_best", [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14]
)
@patch("sys.stdout", new_callable=StringIO)
@patch("sys.stderr", new_callable=StringIO)
def test_be_main_incompatible_params(self, stdout, stderr):
m1 = (
"be_main: error: argument -b/--best_paramfile: not allowed with "
"argument -p/--hyperparameters"
)
m2 = (
"be_main: error: argument -g/--grid_paramfile: not allowed with "
"argument -p/--hyperparameters"
)
m3 = (
"be_main: error: argument -g/--grid_paramfile: not allowed with "
"argument -p/--hyperparameters"
)
m4 = m1
p0 = [
"-s",
self.score,
"-m",
"SVC",
"--title",
"test",
]
pset = json.dumps(dict(C=17))
p1 = p0.copy()
p1.extend(["-p", pset, "-b"])
p2 = p0.copy()
p2.extend(["-p", pset, "-g"])
p3 = p0.copy()
p3.extend(["-p", pset, "-g", "-b"])
p4 = p0.copy()
p4.extend(["-b", "-g"])
parameters = [(p1, m1), (p2, m2), (p3, m3), (p4, m4)]
for parameter, message in parameters:
with self.assertRaises(SystemExit) as msg:
module = self.search_script("be_main")
module.main(parameter)
self.assertEqual(msg.exception.code, 2)
self.assertEqual(stderr.getvalue(), "")
self.assertRegexpMatches(stdout.getvalue(), message)
def test_be_main_best_params_non_existent(self):
model = "GBC"
stdout, stderr = self.execute_script(
@@ -88,7 +131,7 @@ class BeMainTest(TestBase):
model,
"--title",
"test",
"-f",
"-b",
"-r",
],
)