diff --git a/benchmark/Arguments.py b/benchmark/Arguments.py
index 67a2515..defcabe 100644
--- a/benchmark/Arguments.py
+++ b/benchmark/Arguments.py
@@ -36,6 +36,7 @@ class EnvDefault(argparse.Action):
         self, envvar, required=True, default=None, mandatory=False, **kwargs
     ):
         self._args = EnvData.load()
+        self._overrides = {}
         if required and not mandatory:
             default = self._args[envvar]
             required = False
@@ -51,20 +52,22 @@ class Arguments:
     def __init__(self):
         self.ap = argparse.ArgumentParser()
         models_data = Models.define_models(random_state=0)
+        self._overrides = {}
         self.parameters = {
             "best": [
                 ("-b", "--best"),
                 {
-                    "type": str,
                     "required": False,
+                    "action": "store_true",
+                    "default": False,
                     "help": "best results of models",
                 },
             ],
             "color": [
                 ("-c", "--color"),
                 {
-                    "type": bool,
                     "required": False,
+                    "action": "store_true",
                     "default": False,
                     "help": "use colors for the tree",
                 },
@@ -72,8 +75,9 @@ class Arguments:
             "compare": [
                 ("-c", "--compare"),
                 {
-                    "type": bool,
+                    "action": "store_true",
                     "required": False,
+                    "default": False,
                     "help": "Compare accuracy with best results",
                 },
             ],
@@ -81,6 +85,8 @@ class Arguments:
                 ("-d", "--dataset"),
                 {
                     "type": str,
+                    "envvar": "dataset",  # for compatibility with EnvDefault
+                    "action": EnvDefault,
                     "required": False,
                     "help": "dataset to work with",
                 },
@@ -88,8 +94,8 @@ class Arguments:
             "excel": [
                 ("-x", "--excel"),
                 {
-                    "type": bool,
                     "required": False,
+                    "action": "store_true",
                     "default": False,
                     "help": "Generate Excel File",
                 },
@@ -101,16 +107,17 @@ class Arguments:
             "grid": [
                 ("-g", "--grid"),
                 {
-                    "type": str,
+                    "action": "store_true",
                     "required": False,
+                    "default": False,
                     "help": "grid results of model",
                 },
             ],
             "grid_paramfile": [
                 ("-g", "--grid_paramfile"),
                 {
-                    "type": bool,
                     "required": False,
+                    "action": "store_true",
                     "default": False,
                     "help": "Use best hyperparams file?",
                 },
@@ -118,8 +125,8 @@ class Arguments:
             "hidden": [
                 ("--hidden",),
                 {
-                    "type": str,
                     "required": False,
+                    "action": "store_true",
                     "default": False,
                     "help": "Show hidden results",
                 },
@@ -140,8 +147,8 @@ class Arguments:
             "lose": [
                 ("-l", "--lose"),
                 {
-                    "type": bool,
                     "default": False,
+                    "action": "store_true",
                     "required": False,
                     "help": "show lose results",
                 },
@@ -178,8 +185,9 @@ class Arguments:
             "nan": [
                 ("--nan",),
                 {
-                    "type": bool,
+                    "action": "store_true",
                     "required": False,
+                    "default": False,
                     "help": "Move nan results to hidden folder",
                 },
             ],
@@ -205,7 +213,7 @@ class Arguments:
             "paramfile": [
                 ("-f", "--paramfile"),
                 {
-                    "type": bool,
+                    "action": "store_true",
                     "required": False,
                     "default": False,
                     "help": "Use best hyperparams file?",
@@ -224,7 +232,7 @@ class Arguments:
             "quiet": [
                 ("-q", "--quiet"),
                 {
-                    "type": bool,
+                    "action": "store_true",
                     "required": False,
                     "default": False,
                 },
@@ -232,7 +240,7 @@ class Arguments:
             "report": [
                 ("-r", "--report"),
                 {
-                    "type": bool,
+                    "action": "store_true",
                     "default": False,
                     "required": False,
                     "help": "Report results",
@@ -250,23 +258,27 @@ class Arguments:
             ],
             "sql": [
                 ("-q", "--sql"),
-                {"type": bool, "required": False, "help": "Generate SQL File"},
+                {
+                    "required": False,
+                    "action": "store_true",
+                    "default": False,
+                    "help": "Generate SQL File",
+                },
             ],
             "stratified": [
                 ("-t", "--stratified"),
                 {
                     "action": EnvDefault,
                     "envvar": "stratified",
-                    "type": str,
-                    "required": True,
+                    "required": False,
                     "help": "Stratified",
                 },
             ],
             "tex_output": [
                 ("-t", "--tex-output"),
                 {
-                    "type": bool,
                     "required": False,
+                    "action": "store_true",
                     "default": False,
                     "help": "Generate Tex file with the table",
                 },
@@ -278,8 +290,8 @@ class Arguments:
             "win": [
                 ("-w", "--win"),
                 {
- "type": bool, "default": False, + "action": "store_true", "required": False, "help": "show win results", }, @@ -287,12 +299,20 @@ class Arguments: } def xset(self, *arg_name, **kwargs): - names, default = self.parameters[arg_name[0]] + names, parameters = self.parameters[arg_name[0]] + if "overrides" in kwargs: + self._overrides[names[0]] = (kwargs["overrides"], kwargs["const"]) + del kwargs["overrides"] self.ap.add_argument( *names, - **{**default, **kwargs}, + **{**parameters, **kwargs}, ) return self def parse(self, args=None): + for key, (dest_key, value) in self._overrides.items(): + if args is None: + args = sys.argv[1:] + if key in args: + args.extend((f"--{dest_key}", value)) return self.ap.parse_args(args) diff --git a/benchmark/scripts/be_main.py b/benchmark/scripts/be_main.py index dcd8b0e..33a3428 100755 --- a/benchmark/scripts/be_main.py +++ b/benchmark/scripts/be_main.py @@ -14,7 +14,10 @@ def main(args_test=None): arguments.xset("stratified").xset("score").xset("model", mandatory=True) arguments.xset("n_folds").xset("platform").xset("quiet").xset("title") arguments.xset("hyperparameters").xset("paramfile").xset("report") - arguments.xset("grid_paramfile").xset("dataset") + arguments.xset("grid_paramfile") + arguments.xset( + "dataset", overrides="title", const="Test with only one dataset" + ) args = arguments.parse(args_test) report = args.report or args.dataset is not None if args.grid_paramfile: diff --git a/benchmark/scripts/be_report.py b/benchmark/scripts/be_report.py index 8ee44f0..8407e6a 100755 --- a/benchmark/scripts/be_report.py +++ b/benchmark/scripts/be_report.py @@ -17,17 +17,17 @@ def main(args_test=None): arguments.xset("score", required=False) args = arguments.parse(args_test) if args.best: - args.grid = None + args.grid = False if args.grid: - args.best = None - if args.file is None and args.best is None and args.grid is None: + args.best = False + if args.file is None and not args.best and not args.grid: report = ReportDatasets(args.excel) report.report() if args.excel: is_test = args_test is not None Files.open(report.get_file_name(), is_test) else: - if args.best is not None or args.grid is not None: + if args.best or args.grid: report = ReportBest(args.score, args.model, args.best, args.grid) report.report() else: diff --git a/benchmark/tests/Arguments_test.py b/benchmark/tests/Arguments_test.py index 312bed1..b56ee0e 100644 --- a/benchmark/tests/Arguments_test.py +++ b/benchmark/tests/Arguments_test.py @@ -98,3 +98,27 @@ class ArgumentsTest(TestBase): finally: os.chdir(path) self.assertEqual(stderr.getvalue(), f"{NO_ENV}\n") + + @patch("sys.stderr", new_callable=StringIO) + def test_overrides(self, stderr): + arguments = self.build_args() + arguments.xset("title") + arguments.xset("dataset", overrides="title", const="sample text") + test_args = ["-n", "3", "-m", "SVC", "-k", "1", "-d", "dataset"] + args = arguments.parse(test_args) + self.assertEqual(stderr.getvalue(), "") + self.assertEqual(args.title, "sample text") + + @patch("sys.stderr", new_callable=StringIO) + def test_overrides_no_args(self, stderr): + arguments = self.build_args() + arguments.xset("title") + arguments.xset("dataset", overrides="title", const="sample text") + test_args = None + with self.assertRaises(SystemExit): + arguments.parse(test_args) + self.assertRegexpMatches( + stderr.getvalue(), + r"error: the following arguments are required: -m/--model, " + "-k/--key, --title", + ) diff --git a/benchmark/tests/scripts/Be_Benchmark_test.py 
index 00dc168..0249473 100644
--- a/benchmark/tests/scripts/Be_Benchmark_test.py
+++ b/benchmark/tests/scripts/Be_Benchmark_test.py
@@ -25,7 +25,7 @@ class BeBenchmarkTest(TestBase):

     def test_be_benchmark_complete(self):
         stdout, stderr = self.execute_script(
-            "be_benchmark", ["-s", self.score, "-q", "1", "-t", "1", "-x", "1"]
+            "be_benchmark", ["-s", self.score, "-q", "-t", "-x"]
         )
         self.assertEqual(stderr.getvalue(), "")
         # Check output
@@ -60,7 +60,7 @@ class BeBenchmarkTest(TestBase):

     def test_be_benchmark_single(self):
         stdout, stderr = self.execute_script(
-            "be_benchmark", ["-s", self.score, "-q", "1"]
+            "be_benchmark", ["-s", self.score, "-q"]
         )
         self.assertEqual(stderr.getvalue(), "")
         # Check output
diff --git a/benchmark/tests/scripts/Be_Best_test.py b/benchmark/tests/scripts/Be_Best_test.py
index 1d61f82..f6e7978 100644
--- a/benchmark/tests/scripts/Be_Best_test.py
+++ b/benchmark/tests/scripts/Be_Best_test.py
@@ -67,7 +67,7 @@ class BeBestTest(TestBase):

     def test_be_build_best_report(self):
         stdout, _ = self.execute_script(
-            "be_build_best", ["-s", "accuracy", "-m", "ODTE", "-r", "1"]
+            "be_build_best", ["-s", "accuracy", "-m", "ODTE", "-r"]
         )
         expected_data = {
             "balance-scale": [
diff --git a/benchmark/tests/scripts/Be_Grid_test.py b/benchmark/tests/scripts/Be_Grid_test.py
index 1adb296..029b851 100644
--- a/benchmark/tests/scripts/Be_Grid_test.py
+++ b/benchmark/tests/scripts/Be_Grid_test.py
@@ -69,7 +69,7 @@ class BeGridTest(TestBase):
     def test_be_grid_no_input(self):
         stdout, stderr = self.execute_script(
             "be_grid",
-            ["-m", "ODTE", "-s", "f1-weighted", "-q", "1"],
+            ["-m", "ODTE", "-s", "f1-weighted", "-q"],
         )
         self.assertEqual(stderr.getvalue(), "")
         grid_file = os.path.join(
diff --git a/benchmark/tests/scripts/Be_List_test.py b/benchmark/tests/scripts/Be_List_test.py
index b4011e9..c587622 100644
--- a/benchmark/tests/scripts/Be_List_test.py
+++ b/benchmark/tests/scripts/Be_List_test.py
@@ -29,9 +29,7 @@ class BeListTest(TestBase):

     @patch("benchmark.Results.get_input", side_effect=iter(["q"]))
     def test_be_list_report_excel_none(self, input_data):
-        stdout, stderr = self.execute_script(
-            "be_list", ["-m", "STree", "-x", "1"]
-        )
+        stdout, stderr = self.execute_script("be_list", ["-m", "STree", "-x"])
         self.assertEqual(stderr.getvalue(), "")
         self.check_output_file(stdout, "be_list_model")

@@ -43,9 +41,7 @@ class BeListTest(TestBase):

     @patch("benchmark.Results.get_input", side_effect=iter(["2", "q"]))
     def test_be_list_report_excel(self, input_data):
-        stdout, stderr = self.execute_script(
-            "be_list", ["-m", "STree", "-x", "1"]
-        )
+        stdout, stderr = self.execute_script("be_list", ["-m", "STree", "-x"])
         self.assertEqual(stderr.getvalue(), "")
         self.check_output_file(stdout, "be_list_report_excel")
         book = load_workbook(Files.be_list_excel)
@@ -54,9 +50,7 @@ class BeListTest(TestBase):

     @patch("benchmark.Results.get_input", side_effect=iter(["2", "1", "q"]))
     def test_be_list_report_excel_twice(self, input_data):
-        stdout, stderr = self.execute_script(
-            "be_list", ["-m", "STree", "-x", "1"]
-        )
+        stdout, stderr = self.execute_script("be_list", ["-m", "STree", "-x"])
        self.assertEqual(stderr.getvalue(), "")
         self.check_output_file(stdout, "be_list_report_excel_2")
         book = load_workbook(Files.be_list_excel)
@@ -87,7 +81,7 @@ class BeListTest(TestBase):
         swap_files(Folders.hidden_results, Folders.results, file_name)
         try:
             # list and move nan result to hidden
-            stdout, stderr = self.execute_script("be_list", ["--nan", "1"])
self.execute_script("be_list", ["--nan"]) self.assertEqual(stderr.getvalue(), "") self.check_output_file(stdout, "be_list_nan") except Exception: @@ -97,7 +91,7 @@ class BeListTest(TestBase): @patch("benchmark.Results.get_input", return_value="q") def test_be_list_nan_no_nan(self, input_data): - stdout, stderr = self.execute_script("be_list", ["--nan", "1"]) + stdout, stderr = self.execute_script("be_list", ["--nan"]) self.assertEqual(stderr.getvalue(), "") self.check_output_file(stdout, "be_list_no_nan") diff --git a/benchmark/tests/scripts/Be_Main_test.py b/benchmark/tests/scripts/Be_Main_test.py index 3b17da8..6464174 100644 --- a/benchmark/tests/scripts/Be_Main_test.py +++ b/benchmark/tests/scripts/Be_Main_test.py @@ -30,7 +30,7 @@ class BeMainTest(TestBase): def test_be_main_complete(self): stdout, _ = self.execute_script( "be_main", - ["-s", self.score, "-m", "STree", "--title", "test", "-r", "1"], + ["-s", self.score, "-m", "STree", "--title", "test", "-r"], ) # keep the report name to delete it after report_name = stdout.getvalue().splitlines()[-1].split("in ")[1] @@ -67,9 +67,7 @@ class BeMainTest(TestBase): "--title", "test", "-f", - "1", "-r", - "1", ], ) # keep the report name to delete it after @@ -91,9 +89,7 @@ class BeMainTest(TestBase): "--title", "test", "-f", - "1", "-r", - "1", ], ) self.assertEqual(stderr.getvalue(), "") @@ -117,9 +113,7 @@ class BeMainTest(TestBase): "--title", "test", "-g", - "1", "-r", - "1", ], ) self.assertEqual(stderr.getvalue(), "") @@ -142,9 +136,7 @@ class BeMainTest(TestBase): "--title", "test", "-g", - "1", "-r", - "1", ], ) # keep the report name to delete it after diff --git a/benchmark/tests/scripts/Be_Print_Strees_test.py b/benchmark/tests/scripts/Be_Print_Strees_test.py index 95d64c1..3e7dde9 100644 --- a/benchmark/tests/scripts/Be_Print_Strees_test.py +++ b/benchmark/tests/scripts/Be_Print_Strees_test.py @@ -18,7 +18,7 @@ class BePrintStrees(TestBase): for name in self.datasets: stdout, _ = self.execute_script( "be_print_strees", - ["-d", name, "-q", "1"], + ["-d", name, "-q"], ) file_name = os.path.join(Folders.img, f"stree_{name}.png") self.files.append(file_name) @@ -33,7 +33,7 @@ class BePrintStrees(TestBase): for name in self.datasets: stdout, _ = self.execute_script( "be_print_strees", - ["-d", name, "-q", "1", "-c", "1"], + ["-d", name, "-q", "-c"], ) file_name = os.path.join(Folders.img, f"stree_{name}.png") self.files.append(file_name) diff --git a/benchmark/tests/scripts/Be_Report_test.py b/benchmark/tests/scripts/Be_Report_test.py index 95b1a7c..d8a3980 100644 --- a/benchmark/tests/scripts/Be_Report_test.py +++ b/benchmark/tests/scripts/Be_Report_test.py @@ -35,7 +35,7 @@ class BeReportTest(TestBase): def test_be_report_compare(self): file_name = "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json" stdout, stderr = self.execute_script( - "be_report", ["-f", file_name, "-c", "1"] + "be_report", ["-f", file_name, "-c"] ) self.assertEqual(stderr.getvalue(), "") self.check_output_file(stdout, "report_compared") @@ -54,7 +54,7 @@ class BeReportTest(TestBase): self.assertEqual(line, output_text[index]) def test_be_report_datasets_excel(self): - stdout, stderr = self.execute_script("be_report", ["-x", "1"]) + stdout, stderr = self.execute_script("be_report", ["-x"]) self.assertEqual(stderr.getvalue(), "") file_name = f"report_datasets{self.ext}" with open(os.path.join(self.test_files, file_name)) as f: @@ -77,14 +77,14 @@ class BeReportTest(TestBase): def test_be_report_best(self): stdout, stderr = self.execute_script( - 
"be_report", ["-s", "accuracy", "-m", "STree", "-b", "1"] + "be_report", ["-s", "accuracy", "-m", "STree", "-b"] ) self.assertEqual(stderr.getvalue(), "") self.check_output_file(stdout, "report_best") def test_be_report_grid(self): stdout, stderr = self.execute_script( - "be_report", ["-s", "accuracy", "-m", "STree", "-g", "1"] + "be_report", ["-s", "accuracy", "-m", "STree", "-g"] ) self.assertEqual(stderr.getvalue(), "") file_name = "report_grid.test" @@ -101,7 +101,7 @@ class BeReportTest(TestBase): def test_be_report_best_both(self): stdout, stderr = self.execute_script( "be_report", - ["-s", "accuracy", "-m", "STree", "-b", "1", "-g", "1"], + ["-s", "accuracy", "-m", "STree", "-b", "-g"], ) self.assertEqual(stderr.getvalue(), "") self.check_output_file(stdout, "report_best") @@ -110,7 +110,7 @@ class BeReportTest(TestBase): file_name = "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json" stdout, stderr = self.execute_script( "be_report", - ["-f", file_name, "-x", "1", "-c", "1"], + ["-f", file_name, "-x", "-c"], ) file_name = os.path.join( Folders.results, file_name.replace(".json", ".xlsx") @@ -125,7 +125,7 @@ class BeReportTest(TestBase): file_name = "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json" stdout, stderr = self.execute_script( "be_report", - ["-f", file_name, "-x", "1"], + ["-f", file_name, "-x"], ) file_name = os.path.join( Folders.results, file_name.replace(".json", ".xlsx") @@ -140,7 +140,7 @@ class BeReportTest(TestBase): file_name = "results_accuracy_ODTE_Galgo_2022-04-20_10:52:20_0.json" stdout, stderr = self.execute_script( "be_report", - ["-f", file_name, "-q", "1"], + ["-f", file_name, "-q"], ) file_name = os.path.join( Folders.results, file_name.replace(".json", ".sql") diff --git a/benchmark/tests/test_files/be_main_dataset.test b/benchmark/tests/test_files/be_main_dataset.test index 3fedf46..4c56e14 100644 --- a/benchmark/tests/test_files/be_main_dataset.test +++ b/benchmark/tests/test_files/be_main_dataset.test @@ -1,6 +1,6 @@ ************************************************************************************************************************* * STree ver. 1.2.4 Python ver. 3.11x with 5 Folds cross validation and 10 random seeds. 2022-05-08 19:38:28 * -* test * +* Test with only one dataset * * Random seeds: [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] Stratified: False * * Execution took 0.06 seconds, 0.00 hours, on iMac27 * * Score is accuracy *