Refactor scripts testing

This commit is contained in:
2022-05-06 23:05:43 +02:00
parent bb0821c56e
commit 3b214773ff
10 changed files with 93 additions and 67 deletions

View File

@@ -4,10 +4,10 @@ from .Utils import Files
ALL_METRICS = (
"accuracy",
"f1_macro",
"f1_micro",
"f1_weighted",
"roc_auc_ovr",
"f1-macro",
"f1-micro",
"f1-weighted",
"roc-auc-ovr",
)

View File

@@ -70,7 +70,8 @@ class DatasetsSurcov:
)
data.dropna(axis=0, how="any", inplace=True)
self.columns = data.columns
X = data.drop("class", axis=1).to_numpy()
col_list = ["class"]
X = data.drop(col_list, axis=1).to_numpy()
y = data["class"].to_numpy()
return X, y

View File

@@ -8,11 +8,11 @@ from benchmark.Arguments import Arguments
"""
def main():
def main(args_test=None):
arguments = Arguments()
arguments.xset("number").xset("model", required=False).xset("key")
arguments.xset("hidden").xset("nan").xset("score", required=False)
args = arguments.parse()
args = arguments.parse(args_test)
data = Summary(hidden=args.hidden)
data.acquire()
try:

View File

@@ -6,11 +6,11 @@ from benchmark.Arguments import Arguments
"""
def main(argx=None):
def main(args_test=None):
arguments = Arguments()
arguments.xset("score").xset("win").xset("model1").xset("model2")
arguments.xset("lose")
args = arguments.parse(argx)
args = arguments.parse(args_test)
pair_check = PairCheck(
args.score,
args.model1,

View File

@@ -14,10 +14,10 @@ class ArgumentsTest(TestBase):
def test_build_hyperparams_file(self):
expected_metrics = (
"accuracy",
"f1_macro",
"f1_micro",
"f1_weighted",
"roc_auc_ovr",
"f1-macro",
"f1-micro",
"f1-weighted",
"roc-auc-ovr",
)
self.assertSequenceEqual(ALL_METRICS, expected_metrics)

View File

@@ -1,52 +0,0 @@
import os
import sys
import glob
import pathlib
from importlib import import_module
from io import StringIO
from unittest.mock import patch
from .TestBase import TestBase
class ScriptsTest(TestBase):
def setUp(self):
self.scripts_folder = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", "scripts"
)
sys.path.append(self.scripts_folder)
def search_script(self, name):
py_files = glob.glob(os.path.join(self.scripts_folder, "*.py"))
for py_file in py_files:
module_name = pathlib.Path(py_file).stem
if name == module_name:
module = import_module(module_name)
return module
@patch("sys.stdout", new_callable=StringIO)
@patch("sys.stderr", new_callable=StringIO)
def execute_script(self, script, args, stderr, stdout):
module = self.search_script(script)
module.main(args)
return stdout, stderr
def test_be_pair_check(self):
stdout, stderr = self.execute_script(
"be_pair_check", ["-m1", "ODTE", "-m2", "STree"]
)
self.assertEqual(stderr.getvalue(), "")
self.check_output_file(stdout, "paircheck.test")
def test_be_pair_check_no_data_a(self):
stdout, stderr = self.execute_script(
"be_pair_check", ["-m1", "SVC", "-m2", "ODTE"]
)
self.assertEqual(stderr.getvalue(), "")
self.assertEqual(stdout.getvalue(), "** No results found **\n")
def test_be_pair_check_no_data_b(self):
stdout, stderr = self.execute_script(
"be_pair_check", ["-m1", "STree", "-m2", "SVC"]
)
self.assertEqual(stderr.getvalue(), "")
self.assertEqual(stdout.getvalue(), "** No results found **\n")

View File

@@ -1,6 +1,12 @@
import os
import glob
import pathlib
import sys
import csv
import unittest
from importlib import import_module
from io import StringIO
from unittest.mock import patch
class TestBase(unittest.TestCase):
@@ -48,3 +54,24 @@ class TestBase(unittest.TestCase):
with open(os.path.join(self.test_files, expected_file)) as f:
expected = f.read()
self.assertEqual(computed, expected)
def prepare_scripts_env(self):
self.scripts_folder = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", "scripts"
)
sys.path.append(self.scripts_folder)
def search_script(self, name):
py_files = glob.glob(os.path.join(self.scripts_folder, "*.py"))
for py_file in py_files:
module_name = pathlib.Path(py_file).stem
if name == module_name:
module = import_module(module_name)
return module
@patch("sys.stdout", new_callable=StringIO)
@patch("sys.stderr", new_callable=StringIO)
def execute_script(self, script, args, stderr, stdout):
module = self.search_script(script)
module.main(args)
return stdout, stderr

View File

@@ -11,7 +11,8 @@ from .Benchmark_test import BenchmarkTest
from .Summary_test import SummaryTest
from .PairCheck_test import PairCheckTest
from .Arguments_test import ArgumentsTest
from .Scripts_test import ScriptsTest
from .scripts.Pair_check_test import BePairCheckTest
from .scripts.List_test import ListTest
all = [
"UtilTest",
@@ -27,5 +28,6 @@ all = [
"SummaryTest",
"PairCheckTest",
"ArgumentsTest",
"ScriptsTest",
"BePairCheckTest",
"ListTest",
]

View File

@@ -0,0 +1,21 @@
from ..TestBase import TestBase
class ListTest(TestBase):
def setUp(self):
self.prepare_scripts_env()
def test_be_list(self):
stdout, stderr = self.execute_script("be_list", ["-m", "STree"])
self.assertEqual(stderr.getvalue(), "")
self.check_output_file(stdout, "summary_list_model.test")
def test_be_list_no_data(self):
stdout, stderr = self.execute_script(
"be_list", ["-m", "Wodt", "-s", "f1-macro"]
)
self.assertEqual(stderr.getvalue(), "")
self.assertEqual(stdout.getvalue(), "** No results found **\n")
def test_be_list_nan(self):
pass

View File

@@ -0,0 +1,27 @@
from ..TestBase import TestBase
class BePairCheckTest(TestBase):
def setUp(self):
self.prepare_scripts_env()
def test_be_pair_check(self):
stdout, stderr = self.execute_script(
"be_pair_check", ["-m1", "ODTE", "-m2", "STree"]
)
self.assertEqual(stderr.getvalue(), "")
self.check_output_file(stdout, "paircheck.test")
def test_be_pair_check_no_data_a(self):
stdout, stderr = self.execute_script(
"be_pair_check", ["-m1", "SVC", "-m2", "ODTE"]
)
self.assertEqual(stderr.getvalue(), "")
self.assertEqual(stdout.getvalue(), "** No results found **\n")
def test_be_pair_check_no_data_b(self):
stdout, stderr = self.execute_script(
"be_pair_check", ["-m1", "STree", "-m2", "SVC"]
)
self.assertEqual(stderr.getvalue(), "")
self.assertEqual(stdout.getvalue(), "** No results found **\n")