From 3b214773ff6e6b55db2bade8dbc057ab13d572d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Fri, 6 May 2022 23:05:43 +0200 Subject: [PATCH] Refactor scripts testing --- benchmark/Arguments.py | 8 ++-- benchmark/Experiments.py | 3 +- benchmark/scripts/be_list.py | 4 +- benchmark/scripts/be_pair_check.py | 4 +- benchmark/tests/Arguments_test.py | 8 ++-- benchmark/tests/Scripts_test.py | 52 ---------------------- benchmark/tests/TestBase.py | 27 +++++++++++ benchmark/tests/__init__.py | 6 ++- benchmark/tests/scripts/List_test.py | 21 +++++++++ benchmark/tests/scripts/Pair_check_test.py | 27 +++++++++++ 10 files changed, 93 insertions(+), 67 deletions(-) delete mode 100644 benchmark/tests/Scripts_test.py create mode 100644 benchmark/tests/scripts/List_test.py create mode 100644 benchmark/tests/scripts/Pair_check_test.py diff --git a/benchmark/Arguments.py b/benchmark/Arguments.py index d04dabd..1446099 100644 --- a/benchmark/Arguments.py +++ b/benchmark/Arguments.py @@ -4,10 +4,10 @@ from .Utils import Files ALL_METRICS = ( "accuracy", - "f1_macro", - "f1_micro", - "f1_weighted", - "roc_auc_ovr", + "f1-macro", + "f1-micro", + "f1-weighted", + "roc-auc-ovr", ) diff --git a/benchmark/Experiments.py b/benchmark/Experiments.py index 0b289e8..f2a8793 100644 --- a/benchmark/Experiments.py +++ b/benchmark/Experiments.py @@ -70,7 +70,8 @@ class DatasetsSurcov: ) data.dropna(axis=0, how="any", inplace=True) self.columns = data.columns - X = data.drop("class", axis=1).to_numpy() + col_list = ["class"] + X = data.drop(col_list, axis=1).to_numpy() y = data["class"].to_numpy() return X, y diff --git a/benchmark/scripts/be_list.py b/benchmark/scripts/be_list.py index 41ded02..cb31cb0 100755 --- a/benchmark/scripts/be_list.py +++ b/benchmark/scripts/be_list.py @@ -8,11 +8,11 @@ from benchmark.Arguments import Arguments """ -def main(): +def main(args_test=None): arguments = Arguments() arguments.xset("number").xset("model", required=False).xset("key") arguments.xset("hidden").xset("nan").xset("score", required=False) - args = arguments.parse() + args = arguments.parse(args_test) data = Summary(hidden=args.hidden) data.acquire() try: diff --git a/benchmark/scripts/be_pair_check.py b/benchmark/scripts/be_pair_check.py index b005db0..2c83639 100755 --- a/benchmark/scripts/be_pair_check.py +++ b/benchmark/scripts/be_pair_check.py @@ -6,11 +6,11 @@ from benchmark.Arguments import Arguments """ -def main(argx=None): +def main(args_test=None): arguments = Arguments() arguments.xset("score").xset("win").xset("model1").xset("model2") arguments.xset("lose") - args = arguments.parse(argx) + args = arguments.parse(args_test) pair_check = PairCheck( args.score, args.model1, diff --git a/benchmark/tests/Arguments_test.py b/benchmark/tests/Arguments_test.py index 0dbbedf..4e9768f 100644 --- a/benchmark/tests/Arguments_test.py +++ b/benchmark/tests/Arguments_test.py @@ -14,10 +14,10 @@ class ArgumentsTest(TestBase): def test_build_hyperparams_file(self): expected_metrics = ( "accuracy", - "f1_macro", - "f1_micro", - "f1_weighted", - "roc_auc_ovr", + "f1-macro", + "f1-micro", + "f1-weighted", + "roc-auc-ovr", ) self.assertSequenceEqual(ALL_METRICS, expected_metrics) diff --git a/benchmark/tests/Scripts_test.py b/benchmark/tests/Scripts_test.py deleted file mode 100644 index 9a1d21f..0000000 --- a/benchmark/tests/Scripts_test.py +++ /dev/null @@ -1,52 +0,0 @@ -import os -import sys -import glob -import pathlib -from importlib import import_module -from io import StringIO -from unittest.mock import patch -from .TestBase import TestBase - - -class ScriptsTest(TestBase): - def setUp(self): - self.scripts_folder = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "..", "scripts" - ) - sys.path.append(self.scripts_folder) - - def search_script(self, name): - py_files = glob.glob(os.path.join(self.scripts_folder, "*.py")) - for py_file in py_files: - module_name = pathlib.Path(py_file).stem - if name == module_name: - module = import_module(module_name) - return module - - @patch("sys.stdout", new_callable=StringIO) - @patch("sys.stderr", new_callable=StringIO) - def execute_script(self, script, args, stderr, stdout): - module = self.search_script(script) - module.main(args) - return stdout, stderr - - def test_be_pair_check(self): - stdout, stderr = self.execute_script( - "be_pair_check", ["-m1", "ODTE", "-m2", "STree"] - ) - self.assertEqual(stderr.getvalue(), "") - self.check_output_file(stdout, "paircheck.test") - - def test_be_pair_check_no_data_a(self): - stdout, stderr = self.execute_script( - "be_pair_check", ["-m1", "SVC", "-m2", "ODTE"] - ) - self.assertEqual(stderr.getvalue(), "") - self.assertEqual(stdout.getvalue(), "** No results found **\n") - - def test_be_pair_check_no_data_b(self): - stdout, stderr = self.execute_script( - "be_pair_check", ["-m1", "STree", "-m2", "SVC"] - ) - self.assertEqual(stderr.getvalue(), "") - self.assertEqual(stdout.getvalue(), "** No results found **\n") diff --git a/benchmark/tests/TestBase.py b/benchmark/tests/TestBase.py index 13778b9..277c70a 100644 --- a/benchmark/tests/TestBase.py +++ b/benchmark/tests/TestBase.py @@ -1,6 +1,12 @@ import os +import glob +import pathlib +import sys import csv import unittest +from importlib import import_module +from io import StringIO +from unittest.mock import patch class TestBase(unittest.TestCase): @@ -48,3 +54,24 @@ class TestBase(unittest.TestCase): with open(os.path.join(self.test_files, expected_file)) as f: expected = f.read() self.assertEqual(computed, expected) + + def prepare_scripts_env(self): + self.scripts_folder = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "..", "scripts" + ) + sys.path.append(self.scripts_folder) + + def search_script(self, name): + py_files = glob.glob(os.path.join(self.scripts_folder, "*.py")) + for py_file in py_files: + module_name = pathlib.Path(py_file).stem + if name == module_name: + module = import_module(module_name) + return module + + @patch("sys.stdout", new_callable=StringIO) + @patch("sys.stderr", new_callable=StringIO) + def execute_script(self, script, args, stderr, stdout): + module = self.search_script(script) + module.main(args) + return stdout, stderr diff --git a/benchmark/tests/__init__.py b/benchmark/tests/__init__.py index 9c81aae..fbbca7f 100644 --- a/benchmark/tests/__init__.py +++ b/benchmark/tests/__init__.py @@ -11,7 +11,8 @@ from .Benchmark_test import BenchmarkTest from .Summary_test import SummaryTest from .PairCheck_test import PairCheckTest from .Arguments_test import ArgumentsTest -from .Scripts_test import ScriptsTest +from .scripts.Pair_check_test import BePairCheckTest +from .scripts.List_test import ListTest all = [ "UtilTest", @@ -27,5 +28,6 @@ all = [ "SummaryTest", "PairCheckTest", "ArgumentsTest", - "ScriptsTest", + "BePairCheckTest", + "ListTest", ] diff --git a/benchmark/tests/scripts/List_test.py b/benchmark/tests/scripts/List_test.py new file mode 100644 index 0000000..ccbec25 --- /dev/null +++ b/benchmark/tests/scripts/List_test.py @@ -0,0 +1,21 @@ +from ..TestBase import TestBase + + +class ListTest(TestBase): + def setUp(self): + self.prepare_scripts_env() + + def test_be_list(self): + stdout, stderr = self.execute_script("be_list", ["-m", "STree"]) + self.assertEqual(stderr.getvalue(), "") + self.check_output_file(stdout, "summary_list_model.test") + + def test_be_list_no_data(self): + stdout, stderr = self.execute_script( + "be_list", ["-m", "Wodt", "-s", "f1-macro"] + ) + self.assertEqual(stderr.getvalue(), "") + self.assertEqual(stdout.getvalue(), "** No results found **\n") + + def test_be_list_nan(self): + pass diff --git a/benchmark/tests/scripts/Pair_check_test.py b/benchmark/tests/scripts/Pair_check_test.py new file mode 100644 index 0000000..60628b5 --- /dev/null +++ b/benchmark/tests/scripts/Pair_check_test.py @@ -0,0 +1,27 @@ +from ..TestBase import TestBase + + +class BePairCheckTest(TestBase): + def setUp(self): + self.prepare_scripts_env() + + def test_be_pair_check(self): + stdout, stderr = self.execute_script( + "be_pair_check", ["-m1", "ODTE", "-m2", "STree"] + ) + self.assertEqual(stderr.getvalue(), "") + self.check_output_file(stdout, "paircheck.test") + + def test_be_pair_check_no_data_a(self): + stdout, stderr = self.execute_script( + "be_pair_check", ["-m1", "SVC", "-m2", "ODTE"] + ) + self.assertEqual(stderr.getvalue(), "") + self.assertEqual(stdout.getvalue(), "** No results found **\n") + + def test_be_pair_check_no_data_b(self): + stdout, stderr = self.execute_script( + "be_pair_check", ["-m1", "STree", "-m2", "SVC"] + ) + self.assertEqual(stderr.getvalue(), "") + self.assertEqual(stdout.getvalue(), "** No results found **\n")