mirror of
https://github.com/Doctorado-ML/benchmark.git
synced 2025-08-16 07:55:54 +00:00
Add incompatible hyparams to be_main
This commit is contained in:
@@ -55,13 +55,13 @@ class Arguments(argparse.ArgumentParser):
|
||||
self._overrides = {}
|
||||
self._subparser = None
|
||||
self.parameters = {
|
||||
"best": [
|
||||
("-b", "--best"),
|
||||
"best_paramfile": [
|
||||
("-b", "--best_paramfile"),
|
||||
{
|
||||
"required": False,
|
||||
"action": "store_true",
|
||||
"required": False,
|
||||
"default": False,
|
||||
"help": "best results of models",
|
||||
"help": "Use best hyperparams file?",
|
||||
},
|
||||
],
|
||||
"color": [
|
||||
@@ -107,7 +107,7 @@ class Arguments(argparse.ArgumentParser):
|
||||
"required": False,
|
||||
"action": "store_true",
|
||||
"default": False,
|
||||
"help": "Use best hyperparams file?",
|
||||
"help": "Use grid output hyperparams file?",
|
||||
},
|
||||
],
|
||||
"hidden": [
|
||||
@@ -198,15 +198,6 @@ class Arguments(argparse.ArgumentParser):
|
||||
"help": "number of folds",
|
||||
},
|
||||
],
|
||||
"paramfile": [
|
||||
("-f", "--paramfile"),
|
||||
{
|
||||
"action": "store_true",
|
||||
"required": False,
|
||||
"default": False,
|
||||
"help": "Use best hyperparams file?",
|
||||
},
|
||||
],
|
||||
"platform": [
|
||||
("-P", "--platform"),
|
||||
{
|
||||
@@ -314,6 +305,12 @@ class Arguments(argparse.ArgumentParser):
|
||||
# Order of args is important
|
||||
parser.add_argument(*names, **{**args, **parameters})
|
||||
|
||||
def add_exclusive(self, hyperparameters, required=False):
|
||||
group = self.add_mutually_exclusive_group(required=required)
|
||||
for name in hyperparameters:
|
||||
names, parameters = self.parameters[name]
|
||||
group.add_argument(*names, **parameters)
|
||||
|
||||
def parse(self, args=None):
|
||||
for key, (dest_key, value) in self._overrides.items():
|
||||
if args is None:
|
||||
|
@@ -10,18 +10,20 @@ from benchmark.Arguments import Arguments
|
||||
|
||||
|
||||
def main(args_test=None):
|
||||
arguments = Arguments()
|
||||
arguments = Arguments(prog="be_main")
|
||||
arguments.xset("stratified").xset("score").xset("model", mandatory=True)
|
||||
arguments.xset("n_folds").xset("platform").xset("quiet").xset("title")
|
||||
arguments.xset("hyperparameters").xset("paramfile").xset("report")
|
||||
arguments.xset("grid_paramfile")
|
||||
arguments.xset("report")
|
||||
arguments.add_exclusive(
|
||||
["grid_paramfile", "best_paramfile", "hyperparameters"]
|
||||
)
|
||||
arguments.xset(
|
||||
"dataset", overrides="title", const="Test with only one dataset"
|
||||
)
|
||||
args = arguments.parse(args_test)
|
||||
report = args.report or args.dataset is not None
|
||||
if args.grid_paramfile:
|
||||
args.paramfile = False
|
||||
args.best_paramfile = False
|
||||
try:
|
||||
job = Experiment(
|
||||
score_name=args.score,
|
||||
@@ -29,7 +31,7 @@ def main(args_test=None):
|
||||
stratified=args.stratified,
|
||||
datasets=Datasets(dataset_name=args.dataset),
|
||||
hyperparams_dict=args.hyperparameters,
|
||||
hyperparams_file=args.paramfile,
|
||||
hyperparams_file=args.best_paramfile,
|
||||
grid_paramfile=args.grid_paramfile,
|
||||
progress_bar=not args.quiet,
|
||||
platform=args.platform,
|
||||
|
@@ -24,6 +24,7 @@ class ArgumentsTest(TestBase):
|
||||
|
||||
def test_parameters(self):
|
||||
expected_parameters = {
|
||||
"best_paramfile": ("-b", "--best_paramfile"),
|
||||
"color": ("-c", "--color"),
|
||||
"compare": ("-c", "--compare"),
|
||||
"dataset": ("-d", "--dataset"),
|
||||
@@ -39,7 +40,6 @@ class ArgumentsTest(TestBase):
|
||||
"nan": ("--nan",),
|
||||
"number": ("-n", "--number"),
|
||||
"n_folds": ("-n", "--n_folds"),
|
||||
"paramfile": ("-f", "--paramfile"),
|
||||
"platform": ("-P", "--platform"),
|
||||
"quiet": ("-q", "--quiet"),
|
||||
"report": ("-r", "--report"),
|
||||
|
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import json
|
||||
from io import StringIO
|
||||
from unittest.mock import patch
|
||||
from ...Results import Report
|
||||
@@ -66,7 +67,7 @@ class BeMainTest(TestBase):
|
||||
"STree",
|
||||
"--title",
|
||||
"test",
|
||||
"-f",
|
||||
"-b",
|
||||
"-r",
|
||||
],
|
||||
)
|
||||
@@ -77,6 +78,48 @@ class BeMainTest(TestBase):
|
||||
stdout, "be_main_best", [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14]
|
||||
)
|
||||
|
||||
@patch("sys.stdout", new_callable=StringIO)
|
||||
@patch("sys.stderr", new_callable=StringIO)
|
||||
def test_be_main_incompatible_params(self, stdout, stderr):
|
||||
m1 = (
|
||||
"be_main: error: argument -b/--best_paramfile: not allowed with "
|
||||
"argument -p/--hyperparameters"
|
||||
)
|
||||
m2 = (
|
||||
"be_main: error: argument -g/--grid_paramfile: not allowed with "
|
||||
"argument -p/--hyperparameters"
|
||||
)
|
||||
m3 = (
|
||||
"be_main: error: argument -g/--grid_paramfile: not allowed with "
|
||||
"argument -p/--hyperparameters"
|
||||
)
|
||||
m4 = m1
|
||||
p0 = [
|
||||
"-s",
|
||||
self.score,
|
||||
"-m",
|
||||
"SVC",
|
||||
"--title",
|
||||
"test",
|
||||
]
|
||||
pset = json.dumps(dict(C=17))
|
||||
p1 = p0.copy()
|
||||
p1.extend(["-p", pset, "-b"])
|
||||
p2 = p0.copy()
|
||||
p2.extend(["-p", pset, "-g"])
|
||||
p3 = p0.copy()
|
||||
p3.extend(["-p", pset, "-g", "-b"])
|
||||
p4 = p0.copy()
|
||||
p4.extend(["-b", "-g"])
|
||||
parameters = [(p1, m1), (p2, m2), (p3, m3), (p4, m4)]
|
||||
for parameter, message in parameters:
|
||||
with self.assertRaises(SystemExit) as msg:
|
||||
module = self.search_script("be_main")
|
||||
module.main(parameter)
|
||||
self.assertEqual(msg.exception.code, 2)
|
||||
self.assertEqual(stderr.getvalue(), "")
|
||||
self.assertRegexpMatches(stdout.getvalue(), message)
|
||||
|
||||
def test_be_main_best_params_non_existent(self):
|
||||
model = "GBC"
|
||||
stdout, stderr = self.execute_script(
|
||||
@@ -88,7 +131,7 @@ class BeMainTest(TestBase):
|
||||
model,
|
||||
"--title",
|
||||
"test",
|
||||
"-f",
|
||||
"-b",
|
||||
"-r",
|
||||
],
|
||||
)
|
||||
|
Reference in New Issue
Block a user