Add overrides to args parse for dataset/title in be_main

This commit is contained in:
2022-11-19 21:16:29 +01:00
parent 68d9cb776e
commit 07172b91c5
12 changed files with 92 additions and 59 deletions

View File

@@ -36,6 +36,7 @@ class EnvDefault(argparse.Action):
self, envvar, required=True, default=None, mandatory=False, **kwargs self, envvar, required=True, default=None, mandatory=False, **kwargs
): ):
self._args = EnvData.load() self._args = EnvData.load()
self._overrides = {}
if required and not mandatory: if required and not mandatory:
default = self._args[envvar] default = self._args[envvar]
required = False required = False
@@ -51,20 +52,22 @@ class Arguments:
def __init__(self): def __init__(self):
self.ap = argparse.ArgumentParser() self.ap = argparse.ArgumentParser()
models_data = Models.define_models(random_state=0) models_data = Models.define_models(random_state=0)
self._overrides = {}
self.parameters = { self.parameters = {
"best": [ "best": [
("-b", "--best"), ("-b", "--best"),
{ {
"type": str,
"required": False, "required": False,
"action": "store_true",
"default": False,
"help": "best results of models", "help": "best results of models",
}, },
], ],
"color": [ "color": [
("-c", "--color"), ("-c", "--color"),
{ {
"type": bool,
"required": False, "required": False,
"action": "store_true",
"default": False, "default": False,
"help": "use colors for the tree", "help": "use colors for the tree",
}, },
@@ -72,8 +75,9 @@ class Arguments:
"compare": [ "compare": [
("-c", "--compare"), ("-c", "--compare"),
{ {
"type": bool, "action": "store_true",
"required": False, "required": False,
"default": False,
"help": "Compare accuracy with best results", "help": "Compare accuracy with best results",
}, },
], ],
@@ -81,6 +85,8 @@ class Arguments:
("-d", "--dataset"), ("-d", "--dataset"),
{ {
"type": str, "type": str,
"envvar": "dataset", # for compatiblity with EnvDefault
"action": EnvDefault,
"required": False, "required": False,
"help": "dataset to work with", "help": "dataset to work with",
}, },
@@ -88,8 +94,8 @@ class Arguments:
"excel": [ "excel": [
("-x", "--excel"), ("-x", "--excel"),
{ {
"type": bool,
"required": False, "required": False,
"action": "store_true",
"default": False, "default": False,
"help": "Generate Excel File", "help": "Generate Excel File",
}, },
@@ -101,16 +107,17 @@ class Arguments:
"grid": [ "grid": [
("-g", "--grid"), ("-g", "--grid"),
{ {
"type": str, "action": "store_true",
"required": False, "required": False,
"default": False,
"help": "grid results of model", "help": "grid results of model",
}, },
], ],
"grid_paramfile": [ "grid_paramfile": [
("-g", "--grid_paramfile"), ("-g", "--grid_paramfile"),
{ {
"type": bool,
"required": False, "required": False,
"action": "store_true",
"default": False, "default": False,
"help": "Use best hyperparams file?", "help": "Use best hyperparams file?",
}, },
@@ -118,8 +125,8 @@ class Arguments:
"hidden": [ "hidden": [
("--hidden",), ("--hidden",),
{ {
"type": str,
"required": False, "required": False,
"action": "store_true",
"default": False, "default": False,
"help": "Show hidden results", "help": "Show hidden results",
}, },
@@ -140,8 +147,8 @@ class Arguments:
"lose": [ "lose": [
("-l", "--lose"), ("-l", "--lose"),
{ {
"type": bool,
"default": False, "default": False,
"action": "store_true",
"required": False, "required": False,
"help": "show lose results", "help": "show lose results",
}, },
@@ -178,8 +185,9 @@ class Arguments:
"nan": [ "nan": [
("--nan",), ("--nan",),
{ {
"type": bool, "action": "store_true",
"required": False, "required": False,
"default": False,
"help": "Move nan results to hidden folder", "help": "Move nan results to hidden folder",
}, },
], ],
@@ -205,7 +213,7 @@ class Arguments:
"paramfile": [ "paramfile": [
("-f", "--paramfile"), ("-f", "--paramfile"),
{ {
"type": bool, "action": "store_true",
"required": False, "required": False,
"default": False, "default": False,
"help": "Use best hyperparams file?", "help": "Use best hyperparams file?",
@@ -224,7 +232,7 @@ class Arguments:
"quiet": [ "quiet": [
("-q", "--quiet"), ("-q", "--quiet"),
{ {
"type": bool, "action": "store_true",
"required": False, "required": False,
"default": False, "default": False,
}, },
@@ -232,7 +240,7 @@ class Arguments:
"report": [ "report": [
("-r", "--report"), ("-r", "--report"),
{ {
"type": bool, "action": "store_true",
"default": False, "default": False,
"required": False, "required": False,
"help": "Report results", "help": "Report results",
@@ -250,23 +258,27 @@ class Arguments:
], ],
"sql": [ "sql": [
("-q", "--sql"), ("-q", "--sql"),
{"type": bool, "required": False, "help": "Generate SQL File"}, {
"required": False,
"action": "store_true",
"default": False,
"help": "Generate SQL File",
},
], ],
"stratified": [ "stratified": [
("-t", "--stratified"), ("-t", "--stratified"),
{ {
"action": EnvDefault, "action": EnvDefault,
"envvar": "stratified", "envvar": "stratified",
"type": str, "required": False,
"required": True,
"help": "Stratified", "help": "Stratified",
}, },
], ],
"tex_output": [ "tex_output": [
("-t", "--tex-output"), ("-t", "--tex-output"),
{ {
"type": bool,
"required": False, "required": False,
"action": "store_true",
"default": False, "default": False,
"help": "Generate Tex file with the table", "help": "Generate Tex file with the table",
}, },
@@ -278,8 +290,8 @@ class Arguments:
"win": [ "win": [
("-w", "--win"), ("-w", "--win"),
{ {
"type": bool,
"default": False, "default": False,
"action": "store_true",
"required": False, "required": False,
"help": "show win results", "help": "show win results",
}, },
@@ -287,12 +299,20 @@ class Arguments:
} }
def xset(self, *arg_name, **kwargs): def xset(self, *arg_name, **kwargs):
names, default = self.parameters[arg_name[0]] names, parameters = self.parameters[arg_name[0]]
if "overrides" in kwargs:
self._overrides[names[0]] = (kwargs["overrides"], kwargs["const"])
del kwargs["overrides"]
self.ap.add_argument( self.ap.add_argument(
*names, *names,
**{**default, **kwargs}, **{**parameters, **kwargs},
) )
return self return self
def parse(self, args=None): def parse(self, args=None):
for key, (dest_key, value) in self._overrides.items():
if args is None:
args = sys.argv[1:]
if key in args:
args.extend((f"--{dest_key}", value))
return self.ap.parse_args(args) return self.ap.parse_args(args)

View File

@@ -14,7 +14,10 @@ def main(args_test=None):
arguments.xset("stratified").xset("score").xset("model", mandatory=True) arguments.xset("stratified").xset("score").xset("model", mandatory=True)
arguments.xset("n_folds").xset("platform").xset("quiet").xset("title") arguments.xset("n_folds").xset("platform").xset("quiet").xset("title")
arguments.xset("hyperparameters").xset("paramfile").xset("report") arguments.xset("hyperparameters").xset("paramfile").xset("report")
arguments.xset("grid_paramfile").xset("dataset") arguments.xset("grid_paramfile")
arguments.xset(
"dataset", overrides="title", const="Test with only one dataset"
)
args = arguments.parse(args_test) args = arguments.parse(args_test)
report = args.report or args.dataset is not None report = args.report or args.dataset is not None
if args.grid_paramfile: if args.grid_paramfile:

View File

@@ -17,17 +17,17 @@ def main(args_test=None):
arguments.xset("score", required=False) arguments.xset("score", required=False)
args = arguments.parse(args_test) args = arguments.parse(args_test)
if args.best: if args.best:
args.grid = None args.grid = False
if args.grid: if args.grid:
args.best = None args.best = False
if args.file is None and args.best is None and args.grid is None: if args.file is None and not args.best and not args.grid:
report = ReportDatasets(args.excel) report = ReportDatasets(args.excel)
report.report() report.report()
if args.excel: if args.excel:
is_test = args_test is not None is_test = args_test is not None
Files.open(report.get_file_name(), is_test) Files.open(report.get_file_name(), is_test)
else: else:
if args.best is not None or args.grid is not None: if args.best or args.grid:
report = ReportBest(args.score, args.model, args.best, args.grid) report = ReportBest(args.score, args.model, args.best, args.grid)
report.report() report.report()
else: else:

View File

@@ -98,3 +98,27 @@ class ArgumentsTest(TestBase):
finally: finally:
os.chdir(path) os.chdir(path)
self.assertEqual(stderr.getvalue(), f"{NO_ENV}\n") self.assertEqual(stderr.getvalue(), f"{NO_ENV}\n")
@patch("sys.stderr", new_callable=StringIO)
def test_overrides(self, stderr):
arguments = self.build_args()
arguments.xset("title")
arguments.xset("dataset", overrides="title", const="sample text")
test_args = ["-n", "3", "-m", "SVC", "-k", "1", "-d", "dataset"]
args = arguments.parse(test_args)
self.assertEqual(stderr.getvalue(), "")
self.assertEqual(args.title, "sample text")
@patch("sys.stderr", new_callable=StringIO)
def test_overrides_no_args(self, stderr):
arguments = self.build_args()
arguments.xset("title")
arguments.xset("dataset", overrides="title", const="sample text")
test_args = None
with self.assertRaises(SystemExit):
arguments.parse(test_args)
self.assertRegexpMatches(
stderr.getvalue(),
r"error: the following arguments are required: -m/--model, "
"-k/--key, --title",
)

View File

@@ -25,7 +25,7 @@ class BeBenchmarkTest(TestBase):
def test_be_benchmark_complete(self): def test_be_benchmark_complete(self):
stdout, stderr = self.execute_script( stdout, stderr = self.execute_script(
"be_benchmark", ["-s", self.score, "-q", "1", "-t", "1", "-x", "1"] "be_benchmark", ["-s", self.score, "-q", "-t", "-x"]
) )
self.assertEqual(stderr.getvalue(), "") self.assertEqual(stderr.getvalue(), "")
# Check output # Check output
@@ -60,7 +60,7 @@ class BeBenchmarkTest(TestBase):
def test_be_benchmark_single(self): def test_be_benchmark_single(self):
stdout, stderr = self.execute_script( stdout, stderr = self.execute_script(
"be_benchmark", ["-s", self.score, "-q", "1"] "be_benchmark", ["-s", self.score, "-q"]
) )
self.assertEqual(stderr.getvalue(), "") self.assertEqual(stderr.getvalue(), "")
# Check output # Check output

View File

@@ -67,7 +67,7 @@ class BeBestTest(TestBase):
def test_be_build_best_report(self): def test_be_build_best_report(self):
stdout, _ = self.execute_script( stdout, _ = self.execute_script(
"be_build_best", ["-s", "accuracy", "-m", "ODTE", "-r", "1"] "be_build_best", ["-s", "accuracy", "-m", "ODTE", "-r"]
) )
expected_data = { expected_data = {
"balance-scale": [ "balance-scale": [

View File

@@ -69,7 +69,7 @@ class BeGridTest(TestBase):
def test_be_grid_no_input(self): def test_be_grid_no_input(self):
stdout, stderr = self.execute_script( stdout, stderr = self.execute_script(
"be_grid", "be_grid",
["-m", "ODTE", "-s", "f1-weighted", "-q", "1"], ["-m", "ODTE", "-s", "f1-weighted", "-q"],
) )
self.assertEqual(stderr.getvalue(), "") self.assertEqual(stderr.getvalue(), "")
grid_file = os.path.join( grid_file = os.path.join(

View File

@@ -29,9 +29,7 @@ class BeListTest(TestBase):
@patch("benchmark.Results.get_input", side_effect=iter(["q"])) @patch("benchmark.Results.get_input", side_effect=iter(["q"]))
def test_be_list_report_excel_none(self, input_data): def test_be_list_report_excel_none(self, input_data):
stdout, stderr = self.execute_script( stdout, stderr = self.execute_script("be_list", ["-m", "STree", "-x"])
"be_list", ["-m", "STree", "-x", "1"]
)
self.assertEqual(stderr.getvalue(), "") self.assertEqual(stderr.getvalue(), "")
self.check_output_file(stdout, "be_list_model") self.check_output_file(stdout, "be_list_model")
@@ -43,9 +41,7 @@ class BeListTest(TestBase):
@patch("benchmark.Results.get_input", side_effect=iter(["2", "q"])) @patch("benchmark.Results.get_input", side_effect=iter(["2", "q"]))
def test_be_list_report_excel(self, input_data): def test_be_list_report_excel(self, input_data):
stdout, stderr = self.execute_script( stdout, stderr = self.execute_script("be_list", ["-m", "STree", "-x"])
"be_list", ["-m", "STree", "-x", "1"]
)
self.assertEqual(stderr.getvalue(), "") self.assertEqual(stderr.getvalue(), "")
self.check_output_file(stdout, "be_list_report_excel") self.check_output_file(stdout, "be_list_report_excel")
book = load_workbook(Files.be_list_excel) book = load_workbook(Files.be_list_excel)
@@ -54,9 +50,7 @@ class BeListTest(TestBase):
@patch("benchmark.Results.get_input", side_effect=iter(["2", "1", "q"])) @patch("benchmark.Results.get_input", side_effect=iter(["2", "1", "q"]))
def test_be_list_report_excel_twice(self, input_data): def test_be_list_report_excel_twice(self, input_data):
stdout, stderr = self.execute_script( stdout, stderr = self.execute_script("be_list", ["-m", "STree", "-x"])
"be_list", ["-m", "STree", "-x", "1"]
)
self.assertEqual(stderr.getvalue(), "") self.assertEqual(stderr.getvalue(), "")
self.check_output_file(stdout, "be_list_report_excel_2") self.check_output_file(stdout, "be_list_report_excel_2")
book = load_workbook(Files.be_list_excel) book = load_workbook(Files.be_list_excel)
@@ -87,7 +81,7 @@ class BeListTest(TestBase):
swap_files(Folders.hidden_results, Folders.results, file_name) swap_files(Folders.hidden_results, Folders.results, file_name)
try: try:
# list and move nan result to hidden # list and move nan result to hidden
stdout, stderr = self.execute_script("be_list", ["--nan", "1"]) stdout, stderr = self.execute_script("be_list", ["--nan"])
self.assertEqual(stderr.getvalue(), "") self.assertEqual(stderr.getvalue(), "")
self.check_output_file(stdout, "be_list_nan") self.check_output_file(stdout, "be_list_nan")
except Exception: except Exception:
@@ -97,7 +91,7 @@ class BeListTest(TestBase):
@patch("benchmark.Results.get_input", return_value="q") @patch("benchmark.Results.get_input", return_value="q")
def test_be_list_nan_no_nan(self, input_data): def test_be_list_nan_no_nan(self, input_data):
stdout, stderr = self.execute_script("be_list", ["--nan", "1"]) stdout, stderr = self.execute_script("be_list", ["--nan"])
self.assertEqual(stderr.getvalue(), "") self.assertEqual(stderr.getvalue(), "")
self.check_output_file(stdout, "be_list_no_nan") self.check_output_file(stdout, "be_list_no_nan")

View File

@@ -30,7 +30,7 @@ class BeMainTest(TestBase):
def test_be_main_complete(self): def test_be_main_complete(self):
stdout, _ = self.execute_script( stdout, _ = self.execute_script(
"be_main", "be_main",
["-s", self.score, "-m", "STree", "--title", "test", "-r", "1"], ["-s", self.score, "-m", "STree", "--title", "test", "-r"],
) )
# keep the report name to delete it after # keep the report name to delete it after
report_name = stdout.getvalue().splitlines()[-1].split("in ")[1] report_name = stdout.getvalue().splitlines()[-1].split("in ")[1]
@@ -67,9 +67,7 @@ class BeMainTest(TestBase):
"--title", "--title",
"test", "test",
"-f", "-f",
"1",
"-r", "-r",
"1",
], ],
) )
# keep the report name to delete it after # keep the report name to delete it after
@@ -91,9 +89,7 @@ class BeMainTest(TestBase):
"--title", "--title",
"test", "test",
"-f", "-f",
"1",
"-r", "-r",
"1",
], ],
) )
self.assertEqual(stderr.getvalue(), "") self.assertEqual(stderr.getvalue(), "")
@@ -117,9 +113,7 @@ class BeMainTest(TestBase):
"--title", "--title",
"test", "test",
"-g", "-g",
"1",
"-r", "-r",
"1",
], ],
) )
self.assertEqual(stderr.getvalue(), "") self.assertEqual(stderr.getvalue(), "")
@@ -142,9 +136,7 @@ class BeMainTest(TestBase):
"--title", "--title",
"test", "test",
"-g", "-g",
"1",
"-r", "-r",
"1",
], ],
) )
# keep the report name to delete it after # keep the report name to delete it after

View File

@@ -18,7 +18,7 @@ class BePrintStrees(TestBase):
for name in self.datasets: for name in self.datasets:
stdout, _ = self.execute_script( stdout, _ = self.execute_script(
"be_print_strees", "be_print_strees",
["-d", name, "-q", "1"], ["-d", name, "-q"],
) )
file_name = os.path.join(Folders.img, f"stree_{name}.png") file_name = os.path.join(Folders.img, f"stree_{name}.png")
self.files.append(file_name) self.files.append(file_name)
@@ -33,7 +33,7 @@ class BePrintStrees(TestBase):
for name in self.datasets: for name in self.datasets:
stdout, _ = self.execute_script( stdout, _ = self.execute_script(
"be_print_strees", "be_print_strees",
["-d", name, "-q", "1", "-c", "1"], ["-d", name, "-q", "-c"],
) )
file_name = os.path.join(Folders.img, f"stree_{name}.png") file_name = os.path.join(Folders.img, f"stree_{name}.png")
self.files.append(file_name) self.files.append(file_name)

View File

@@ -35,7 +35,7 @@ class BeReportTest(TestBase):
def test_be_report_compare(self): def test_be_report_compare(self):
file_name = "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json" file_name = "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json"
stdout, stderr = self.execute_script( stdout, stderr = self.execute_script(
"be_report", ["-f", file_name, "-c", "1"] "be_report", ["-f", file_name, "-c"]
) )
self.assertEqual(stderr.getvalue(), "") self.assertEqual(stderr.getvalue(), "")
self.check_output_file(stdout, "report_compared") self.check_output_file(stdout, "report_compared")
@@ -54,7 +54,7 @@ class BeReportTest(TestBase):
self.assertEqual(line, output_text[index]) self.assertEqual(line, output_text[index])
def test_be_report_datasets_excel(self): def test_be_report_datasets_excel(self):
stdout, stderr = self.execute_script("be_report", ["-x", "1"]) stdout, stderr = self.execute_script("be_report", ["-x"])
self.assertEqual(stderr.getvalue(), "") self.assertEqual(stderr.getvalue(), "")
file_name = f"report_datasets{self.ext}" file_name = f"report_datasets{self.ext}"
with open(os.path.join(self.test_files, file_name)) as f: with open(os.path.join(self.test_files, file_name)) as f:
@@ -77,14 +77,14 @@ class BeReportTest(TestBase):
def test_be_report_best(self): def test_be_report_best(self):
stdout, stderr = self.execute_script( stdout, stderr = self.execute_script(
"be_report", ["-s", "accuracy", "-m", "STree", "-b", "1"] "be_report", ["-s", "accuracy", "-m", "STree", "-b"]
) )
self.assertEqual(stderr.getvalue(), "") self.assertEqual(stderr.getvalue(), "")
self.check_output_file(stdout, "report_best") self.check_output_file(stdout, "report_best")
def test_be_report_grid(self): def test_be_report_grid(self):
stdout, stderr = self.execute_script( stdout, stderr = self.execute_script(
"be_report", ["-s", "accuracy", "-m", "STree", "-g", "1"] "be_report", ["-s", "accuracy", "-m", "STree", "-g"]
) )
self.assertEqual(stderr.getvalue(), "") self.assertEqual(stderr.getvalue(), "")
file_name = "report_grid.test" file_name = "report_grid.test"
@@ -101,7 +101,7 @@ class BeReportTest(TestBase):
def test_be_report_best_both(self): def test_be_report_best_both(self):
stdout, stderr = self.execute_script( stdout, stderr = self.execute_script(
"be_report", "be_report",
["-s", "accuracy", "-m", "STree", "-b", "1", "-g", "1"], ["-s", "accuracy", "-m", "STree", "-b", "-g"],
) )
self.assertEqual(stderr.getvalue(), "") self.assertEqual(stderr.getvalue(), "")
self.check_output_file(stdout, "report_best") self.check_output_file(stdout, "report_best")
@@ -110,7 +110,7 @@ class BeReportTest(TestBase):
file_name = "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json" file_name = "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json"
stdout, stderr = self.execute_script( stdout, stderr = self.execute_script(
"be_report", "be_report",
["-f", file_name, "-x", "1", "-c", "1"], ["-f", file_name, "-x", "-c"],
) )
file_name = os.path.join( file_name = os.path.join(
Folders.results, file_name.replace(".json", ".xlsx") Folders.results, file_name.replace(".json", ".xlsx")
@@ -125,7 +125,7 @@ class BeReportTest(TestBase):
file_name = "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json" file_name = "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json"
stdout, stderr = self.execute_script( stdout, stderr = self.execute_script(
"be_report", "be_report",
["-f", file_name, "-x", "1"], ["-f", file_name, "-x"],
) )
file_name = os.path.join( file_name = os.path.join(
Folders.results, file_name.replace(".json", ".xlsx") Folders.results, file_name.replace(".json", ".xlsx")
@@ -140,7 +140,7 @@ class BeReportTest(TestBase):
file_name = "results_accuracy_ODTE_Galgo_2022-04-20_10:52:20_0.json" file_name = "results_accuracy_ODTE_Galgo_2022-04-20_10:52:20_0.json"
stdout, stderr = self.execute_script( stdout, stderr = self.execute_script(
"be_report", "be_report",
["-f", file_name, "-q", "1"], ["-f", file_name, "-q"],
) )
file_name = os.path.join( file_name = os.path.join(
Folders.results, file_name.replace(".json", ".sql") Folders.results, file_name.replace(".json", ".sql")

View File

@@ -1,6 +1,6 @@
************************************************************************************************************************* *************************************************************************************************************************
* STree ver. 1.2.4 Python ver. 3.11x with 5 Folds cross validation and 10 random seeds. 2022-05-08 19:38:28 * * STree ver. 1.2.4 Python ver. 3.11x with 5 Folds cross validation and 10 random seeds. 2022-05-08 19:38:28 *
* test * * Test with only one dataset *
* Random seeds: [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] Stratified: False * * Random seeds: [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] Stratified: False *
* Execution took 0.06 seconds, 0.00 hours, on iMac27 * * Execution took 0.06 seconds, 0.00 hours, on iMac27 *
* Score is accuracy * * Score is accuracy *