diff --git a/benchmark/Results.py b/benchmark/Results.py index 4ef9bb5..766079f 100644 --- a/benchmark/Results.py +++ b/benchmark/Results.py @@ -1277,6 +1277,8 @@ class Summary: if criterion is None or value is None else [x for x in haystack if x[criterion] == value] ) + if haystack == []: + raise ValueError("** No results found **") return ( sorted( haystack, diff --git a/benchmark/scripts/be_pair_check.py b/benchmark/scripts/be_pair_check.py index 21e25f2..b005db0 100755 --- a/benchmark/scripts/be_pair_check.py +++ b/benchmark/scripts/be_pair_check.py @@ -6,11 +6,11 @@ from benchmark.Arguments import Arguments """ -def main(): +def main(argx=None): arguments = Arguments() arguments.xset("score").xset("win").xset("model1").xset("model2") arguments.xset("lose") - args = arguments.parse() + args = arguments.parse(argx) pair_check = PairCheck( args.score, args.model1, @@ -18,5 +18,9 @@ def main(): args.win, args.lose, ) - pair_check.compute() - pair_check.report() + try: + pair_check.compute() + except ValueError as e: + print(str(e)) + else: + pair_check.report() diff --git a/benchmark/tests/Arguments_test.py b/benchmark/tests/Arguments_test.py index c941d6c..0dbbedf 100644 --- a/benchmark/tests/Arguments_test.py +++ b/benchmark/tests/Arguments_test.py @@ -1,4 +1,3 @@ -from argparse import ArgumentError from io import StringIO from unittest.mock import patch from .TestBase import TestBase diff --git a/benchmark/tests/PairCheck_test.py b/benchmark/tests/PairCheck_test.py index f4c7b5b..51992a5 100644 --- a/benchmark/tests/PairCheck_test.py +++ b/benchmark/tests/PairCheck_test.py @@ -21,10 +21,7 @@ class PairCheckTest(TestBase): report.compute() with patch(self.output, new=StringIO()) as fake_out: report.report() - computed = fake_out.getvalue() - with open(os.path.join(self.test_files, "paircheck.test"), "r") as f: - expected = f.read() - self.assertEqual(computed, expected) + self.check_output_file(fake_out, "paircheck.test") def test_pair_check_win(self): report = self.build_model(win=True) diff --git a/benchmark/tests/Scripts_test.py b/benchmark/tests/Scripts_test.py new file mode 100644 index 0000000..9a1d21f --- /dev/null +++ b/benchmark/tests/Scripts_test.py @@ -0,0 +1,52 @@ +import os +import sys +import glob +import pathlib +from importlib import import_module +from io import StringIO +from unittest.mock import patch +from .TestBase import TestBase + + +class ScriptsTest(TestBase): + def setUp(self): + self.scripts_folder = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "..", "scripts" + ) + sys.path.append(self.scripts_folder) + + def search_script(self, name): + py_files = glob.glob(os.path.join(self.scripts_folder, "*.py")) + for py_file in py_files: + module_name = pathlib.Path(py_file).stem + if name == module_name: + module = import_module(module_name) + return module + + @patch("sys.stdout", new_callable=StringIO) + @patch("sys.stderr", new_callable=StringIO) + def execute_script(self, script, args, stderr, stdout): + module = self.search_script(script) + module.main(args) + return stdout, stderr + + def test_be_pair_check(self): + stdout, stderr = self.execute_script( + "be_pair_check", ["-m1", "ODTE", "-m2", "STree"] + ) + self.assertEqual(stderr.getvalue(), "") + self.check_output_file(stdout, "paircheck.test") + + def test_be_pair_check_no_data_a(self): + stdout, stderr = self.execute_script( + "be_pair_check", ["-m1", "SVC", "-m2", "ODTE"] + ) + self.assertEqual(stderr.getvalue(), "") + self.assertEqual(stdout.getvalue(), "** No results found **\n") + + def test_be_pair_check_no_data_b(self): + stdout, stderr = self.execute_script( + "be_pair_check", ["-m1", "STree", "-m2", "SVC"] + ) + self.assertEqual(stderr.getvalue(), "") + self.assertEqual(stdout.getvalue(), "** No results found **\n") diff --git a/benchmark/tests/__init__.py b/benchmark/tests/__init__.py index e998410..9c81aae 100644 --- a/benchmark/tests/__init__.py +++ b/benchmark/tests/__init__.py @@ -11,6 +11,7 @@ from .Benchmark_test import BenchmarkTest from .Summary_test import SummaryTest from .PairCheck_test import PairCheckTest from .Arguments_test import ArgumentsTest +from .Scripts_test import ScriptsTest all = [ "UtilTest", @@ -26,4 +27,5 @@ all = [ "SummaryTest", "PairCheckTest", "ArgumentsTest", + "ScriptsTest", ]