mirror of
https://github.com/Doctorado-ML/benchmark.git
synced 2025-08-18 08:55:53 +00:00
Add incompatible hyparams to be_main
This commit is contained in:
@@ -55,13 +55,13 @@ class Arguments(argparse.ArgumentParser):
|
|||||||
self._overrides = {}
|
self._overrides = {}
|
||||||
self._subparser = None
|
self._subparser = None
|
||||||
self.parameters = {
|
self.parameters = {
|
||||||
"best": [
|
"best_paramfile": [
|
||||||
("-b", "--best"),
|
("-b", "--best_paramfile"),
|
||||||
{
|
{
|
||||||
"required": False,
|
|
||||||
"action": "store_true",
|
"action": "store_true",
|
||||||
|
"required": False,
|
||||||
"default": False,
|
"default": False,
|
||||||
"help": "best results of models",
|
"help": "Use best hyperparams file?",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
"color": [
|
"color": [
|
||||||
@@ -107,7 +107,7 @@ class Arguments(argparse.ArgumentParser):
|
|||||||
"required": False,
|
"required": False,
|
||||||
"action": "store_true",
|
"action": "store_true",
|
||||||
"default": False,
|
"default": False,
|
||||||
"help": "Use best hyperparams file?",
|
"help": "Use grid output hyperparams file?",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
"hidden": [
|
"hidden": [
|
||||||
@@ -198,15 +198,6 @@ class Arguments(argparse.ArgumentParser):
|
|||||||
"help": "number of folds",
|
"help": "number of folds",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
"paramfile": [
|
|
||||||
("-f", "--paramfile"),
|
|
||||||
{
|
|
||||||
"action": "store_true",
|
|
||||||
"required": False,
|
|
||||||
"default": False,
|
|
||||||
"help": "Use best hyperparams file?",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"platform": [
|
"platform": [
|
||||||
("-P", "--platform"),
|
("-P", "--platform"),
|
||||||
{
|
{
|
||||||
@@ -314,6 +305,12 @@ class Arguments(argparse.ArgumentParser):
|
|||||||
# Order of args is important
|
# Order of args is important
|
||||||
parser.add_argument(*names, **{**args, **parameters})
|
parser.add_argument(*names, **{**args, **parameters})
|
||||||
|
|
||||||
|
def add_exclusive(self, hyperparameters, required=False):
|
||||||
|
group = self.add_mutually_exclusive_group(required=required)
|
||||||
|
for name in hyperparameters:
|
||||||
|
names, parameters = self.parameters[name]
|
||||||
|
group.add_argument(*names, **parameters)
|
||||||
|
|
||||||
def parse(self, args=None):
|
def parse(self, args=None):
|
||||||
for key, (dest_key, value) in self._overrides.items():
|
for key, (dest_key, value) in self._overrides.items():
|
||||||
if args is None:
|
if args is None:
|
||||||
|
@@ -10,18 +10,20 @@ from benchmark.Arguments import Arguments
|
|||||||
|
|
||||||
|
|
||||||
def main(args_test=None):
|
def main(args_test=None):
|
||||||
arguments = Arguments()
|
arguments = Arguments(prog="be_main")
|
||||||
arguments.xset("stratified").xset("score").xset("model", mandatory=True)
|
arguments.xset("stratified").xset("score").xset("model", mandatory=True)
|
||||||
arguments.xset("n_folds").xset("platform").xset("quiet").xset("title")
|
arguments.xset("n_folds").xset("platform").xset("quiet").xset("title")
|
||||||
arguments.xset("hyperparameters").xset("paramfile").xset("report")
|
arguments.xset("report")
|
||||||
arguments.xset("grid_paramfile")
|
arguments.add_exclusive(
|
||||||
|
["grid_paramfile", "best_paramfile", "hyperparameters"]
|
||||||
|
)
|
||||||
arguments.xset(
|
arguments.xset(
|
||||||
"dataset", overrides="title", const="Test with only one dataset"
|
"dataset", overrides="title", const="Test with only one dataset"
|
||||||
)
|
)
|
||||||
args = arguments.parse(args_test)
|
args = arguments.parse(args_test)
|
||||||
report = args.report or args.dataset is not None
|
report = args.report or args.dataset is not None
|
||||||
if args.grid_paramfile:
|
if args.grid_paramfile:
|
||||||
args.paramfile = False
|
args.best_paramfile = False
|
||||||
try:
|
try:
|
||||||
job = Experiment(
|
job = Experiment(
|
||||||
score_name=args.score,
|
score_name=args.score,
|
||||||
@@ -29,7 +31,7 @@ def main(args_test=None):
|
|||||||
stratified=args.stratified,
|
stratified=args.stratified,
|
||||||
datasets=Datasets(dataset_name=args.dataset),
|
datasets=Datasets(dataset_name=args.dataset),
|
||||||
hyperparams_dict=args.hyperparameters,
|
hyperparams_dict=args.hyperparameters,
|
||||||
hyperparams_file=args.paramfile,
|
hyperparams_file=args.best_paramfile,
|
||||||
grid_paramfile=args.grid_paramfile,
|
grid_paramfile=args.grid_paramfile,
|
||||||
progress_bar=not args.quiet,
|
progress_bar=not args.quiet,
|
||||||
platform=args.platform,
|
platform=args.platform,
|
||||||
|
@@ -24,6 +24,7 @@ class ArgumentsTest(TestBase):
|
|||||||
|
|
||||||
def test_parameters(self):
|
def test_parameters(self):
|
||||||
expected_parameters = {
|
expected_parameters = {
|
||||||
|
"best_paramfile": ("-b", "--best_paramfile"),
|
||||||
"color": ("-c", "--color"),
|
"color": ("-c", "--color"),
|
||||||
"compare": ("-c", "--compare"),
|
"compare": ("-c", "--compare"),
|
||||||
"dataset": ("-d", "--dataset"),
|
"dataset": ("-d", "--dataset"),
|
||||||
@@ -39,7 +40,6 @@ class ArgumentsTest(TestBase):
|
|||||||
"nan": ("--nan",),
|
"nan": ("--nan",),
|
||||||
"number": ("-n", "--number"),
|
"number": ("-n", "--number"),
|
||||||
"n_folds": ("-n", "--n_folds"),
|
"n_folds": ("-n", "--n_folds"),
|
||||||
"paramfile": ("-f", "--paramfile"),
|
|
||||||
"platform": ("-P", "--platform"),
|
"platform": ("-P", "--platform"),
|
||||||
"quiet": ("-q", "--quiet"),
|
"quiet": ("-q", "--quiet"),
|
||||||
"report": ("-r", "--report"),
|
"report": ("-r", "--report"),
|
||||||
|
@@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
import json
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
from ...Results import Report
|
from ...Results import Report
|
||||||
@@ -66,7 +67,7 @@ class BeMainTest(TestBase):
|
|||||||
"STree",
|
"STree",
|
||||||
"--title",
|
"--title",
|
||||||
"test",
|
"test",
|
||||||
"-f",
|
"-b",
|
||||||
"-r",
|
"-r",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@@ -77,6 +78,48 @@ class BeMainTest(TestBase):
|
|||||||
stdout, "be_main_best", [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14]
|
stdout, "be_main_best", [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@patch("sys.stdout", new_callable=StringIO)
|
||||||
|
@patch("sys.stderr", new_callable=StringIO)
|
||||||
|
def test_be_main_incompatible_params(self, stdout, stderr):
|
||||||
|
m1 = (
|
||||||
|
"be_main: error: argument -b/--best_paramfile: not allowed with "
|
||||||
|
"argument -p/--hyperparameters"
|
||||||
|
)
|
||||||
|
m2 = (
|
||||||
|
"be_main: error: argument -g/--grid_paramfile: not allowed with "
|
||||||
|
"argument -p/--hyperparameters"
|
||||||
|
)
|
||||||
|
m3 = (
|
||||||
|
"be_main: error: argument -g/--grid_paramfile: not allowed with "
|
||||||
|
"argument -p/--hyperparameters"
|
||||||
|
)
|
||||||
|
m4 = m1
|
||||||
|
p0 = [
|
||||||
|
"-s",
|
||||||
|
self.score,
|
||||||
|
"-m",
|
||||||
|
"SVC",
|
||||||
|
"--title",
|
||||||
|
"test",
|
||||||
|
]
|
||||||
|
pset = json.dumps(dict(C=17))
|
||||||
|
p1 = p0.copy()
|
||||||
|
p1.extend(["-p", pset, "-b"])
|
||||||
|
p2 = p0.copy()
|
||||||
|
p2.extend(["-p", pset, "-g"])
|
||||||
|
p3 = p0.copy()
|
||||||
|
p3.extend(["-p", pset, "-g", "-b"])
|
||||||
|
p4 = p0.copy()
|
||||||
|
p4.extend(["-b", "-g"])
|
||||||
|
parameters = [(p1, m1), (p2, m2), (p3, m3), (p4, m4)]
|
||||||
|
for parameter, message in parameters:
|
||||||
|
with self.assertRaises(SystemExit) as msg:
|
||||||
|
module = self.search_script("be_main")
|
||||||
|
module.main(parameter)
|
||||||
|
self.assertEqual(msg.exception.code, 2)
|
||||||
|
self.assertEqual(stderr.getvalue(), "")
|
||||||
|
self.assertRegexpMatches(stdout.getvalue(), message)
|
||||||
|
|
||||||
def test_be_main_best_params_non_existent(self):
|
def test_be_main_best_params_non_existent(self):
|
||||||
model = "GBC"
|
model = "GBC"
|
||||||
stdout, stderr = self.execute_script(
|
stdout, stderr = self.execute_script(
|
||||||
@@ -88,7 +131,7 @@ class BeMainTest(TestBase):
|
|||||||
model,
|
model,
|
||||||
"--title",
|
"--title",
|
||||||
"test",
|
"test",
|
||||||
"-f",
|
"-b",
|
||||||
"-r",
|
"-r",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
Reference in New Issue
Block a user