diff --git a/.coveragerc b/.coveragerc index a709d91..232a67e 100644 --- a/.coveragerc +++ b/.coveragerc @@ -11,4 +11,5 @@ exclude_lines = ignore_errors = True omit = benchmark/__init__.py - benchmark/_version.py \ No newline at end of file + benchmark/_version.py + benchmark/tests/* \ No newline at end of file diff --git a/.gitignore b/.gitignore index a15cb7d..dc31840 100644 --- a/.gitignore +++ b/.gitignore @@ -102,7 +102,7 @@ celerybeat.pid *.sage.py # Environments -.env +/.env .venv env/ venv/ diff --git a/benchmark/Utils.py b/benchmark/Utils.py index 576810b..baf6539 100644 --- a/benchmark/Utils.py +++ b/benchmark/Utils.py @@ -102,14 +102,13 @@ class Files: return None def get_all_results(self, hidden) -> list[str]: - first_path = "." - first_try = os.path.join( - first_path, Folders.hidden_results if hidden else Folders.results + result_path = os.path.join( + ".", Folders.hidden_results if hidden else Folders.results ) - if os.path.isdir(first_try): - files_list = os.listdir(first_try) + if os.path.isdir(result_path): + files_list = os.listdir(result_path) else: - raise ValueError(f"{first_try} does not exist") + raise ValueError(f"{result_path} does not exist") result = [] prefix, suffix = self.results_suffixes() for result_file in files_list: @@ -143,17 +142,12 @@ class EnvDefault(argparse.Action): # Thanks to https://stackoverflow.com/users/445507/russell-heilling def __init__(self, envvar, required=True, default=None, **kwargs): self._args = EnvData.load() - if not default and envvar in self._args: - default = self._args[envvar] - if required and default: - required = False + default = self._args[envvar] + required = False super(EnvDefault, self).__init__( default=default, required=required, **kwargs ) - def __call__(self, parser, namespace, values, option_string=None): - setattr(namespace, self.dest, values) - class TextColor: BLUE = "\033[94m" diff --git a/benchmark/tests/.env b/benchmark/tests/.env new file mode 100644 index 0000000..819f93f --- /dev/null +++ b/benchmark/tests/.env @@ -0,0 +1,7 @@ +score=accuracy +platform=iMac27 +n_folds=5 +model=ODTE +stratified=0 +# Source of data Tanveer/Surcov +source_data=Tanveer diff --git a/benchmark/tests/Util_test.py b/benchmark/tests/Util_test.py index 8b0f62a..04c00d4 100644 --- a/benchmark/tests/Util_test.py +++ b/benchmark/tests/Util_test.py @@ -132,6 +132,29 @@ class UtilTest(unittest.TestCase): ["results_accuracy_STree_iMac27_2021-11-01_23:55:16_0.json"], ) + def test_Files_get_results_Error(self): + os.chdir(os.path.dirname(os.path.abspath(__file__))) + # check with results + os.rename(Folders.results, f"{Folders.results}.test") + try: + Files().get_all_results(hidden=False) + except ValueError: + pass + else: + self.fail("Files.get_all_results() should raise ValueError") + finally: + os.rename(f"{Folders.results}.test", Folders.results) + # check with hidden_results + os.rename(Folders.hidden_results, f"{Folders.hidden_results}.test") + try: + Files().get_all_results(hidden=True) + except ValueError: + pass + else: + self.fail("Files.get_all_results() should raise ValueError") + finally: + os.rename(f"{Folders.hidden_results}.test", Folders.hidden_results) + def test_Symbols(self): self.assertEqual(Symbols.check_mark, "\N{heavy check mark}") @@ -156,6 +179,13 @@ class UtilTest(unittest.TestCase): self.assertDictEqual(computed, expected) def test_EnvDefault(self): + expected = { + "score": "accuracy", + "platform": "iMac27", + "n_folds": 5, + "model": "ODTE", + "stratified": "0", + } ap = argparse.ArgumentParser() ap.add_argument( "-s", @@ -178,6 +208,8 @@ class UtilTest(unittest.TestCase): ap.add_argument( "-m", "--model", + action=EnvDefault, + envvar="model", type=str, required=True, help="model name", @@ -200,20 +232,26 @@ class UtilTest(unittest.TestCase): required=True, help="Stratified", ) - # ap.add_argument( - # "--title", - # type=str, - # required=True, - # ) - # args = ap.parse_args([ - # "--title", - # "test", - # ]) - # args = ap.parse_known_args(namespace=unittest) - # computed = args.__dict__ - # for key, value in expected.items(): - # self.assertEqual(computed[key], value) - # print(computed) + ap.add_argument( + "--title", + type=str, + required=True, + ) + ap.add_argument( + "--test", + type=str, + required=False, + default=None, + ) + args = ap.parse_args( + [ + "--title", + "test", + ], + ) + computed = args.__dict__ + for key, value in expected.items(): + self.assertEqual(computed[key], value) def test_TextColor(self): self.assertEqual(TextColor.BLUE, "\033[94m")