mirror of
https://github.com/Doctorado-ML/benchmark.git
synced 2025-08-17 16:35:54 +00:00
Refactor scripts testing
This commit is contained in:
@@ -4,10 +4,10 @@ from .Utils import Files
|
||||
|
||||
ALL_METRICS = (
|
||||
"accuracy",
|
||||
"f1_macro",
|
||||
"f1_micro",
|
||||
"f1_weighted",
|
||||
"roc_auc_ovr",
|
||||
"f1-macro",
|
||||
"f1-micro",
|
||||
"f1-weighted",
|
||||
"roc-auc-ovr",
|
||||
)
|
||||
|
||||
|
||||
|
@@ -70,7 +70,8 @@ class DatasetsSurcov:
|
||||
)
|
||||
data.dropna(axis=0, how="any", inplace=True)
|
||||
self.columns = data.columns
|
||||
X = data.drop("class", axis=1).to_numpy()
|
||||
col_list = ["class"]
|
||||
X = data.drop(col_list, axis=1).to_numpy()
|
||||
y = data["class"].to_numpy()
|
||||
return X, y
|
||||
|
||||
|
@@ -8,11 +8,11 @@ from benchmark.Arguments import Arguments
|
||||
"""
|
||||
|
||||
|
||||
def main():
|
||||
def main(args_test=None):
|
||||
arguments = Arguments()
|
||||
arguments.xset("number").xset("model", required=False).xset("key")
|
||||
arguments.xset("hidden").xset("nan").xset("score", required=False)
|
||||
args = arguments.parse()
|
||||
args = arguments.parse(args_test)
|
||||
data = Summary(hidden=args.hidden)
|
||||
data.acquire()
|
||||
try:
|
||||
|
@@ -6,11 +6,11 @@ from benchmark.Arguments import Arguments
|
||||
"""
|
||||
|
||||
|
||||
def main(argx=None):
|
||||
def main(args_test=None):
|
||||
arguments = Arguments()
|
||||
arguments.xset("score").xset("win").xset("model1").xset("model2")
|
||||
arguments.xset("lose")
|
||||
args = arguments.parse(argx)
|
||||
args = arguments.parse(args_test)
|
||||
pair_check = PairCheck(
|
||||
args.score,
|
||||
args.model1,
|
||||
|
@@ -14,10 +14,10 @@ class ArgumentsTest(TestBase):
|
||||
def test_build_hyperparams_file(self):
|
||||
expected_metrics = (
|
||||
"accuracy",
|
||||
"f1_macro",
|
||||
"f1_micro",
|
||||
"f1_weighted",
|
||||
"roc_auc_ovr",
|
||||
"f1-macro",
|
||||
"f1-micro",
|
||||
"f1-weighted",
|
||||
"roc-auc-ovr",
|
||||
)
|
||||
self.assertSequenceEqual(ALL_METRICS, expected_metrics)
|
||||
|
||||
|
@@ -1,52 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import glob
|
||||
import pathlib
|
||||
from importlib import import_module
|
||||
from io import StringIO
|
||||
from unittest.mock import patch
|
||||
from .TestBase import TestBase
|
||||
|
||||
|
||||
class ScriptsTest(TestBase):
|
||||
def setUp(self):
|
||||
self.scripts_folder = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "..", "scripts"
|
||||
)
|
||||
sys.path.append(self.scripts_folder)
|
||||
|
||||
def search_script(self, name):
|
||||
py_files = glob.glob(os.path.join(self.scripts_folder, "*.py"))
|
||||
for py_file in py_files:
|
||||
module_name = pathlib.Path(py_file).stem
|
||||
if name == module_name:
|
||||
module = import_module(module_name)
|
||||
return module
|
||||
|
||||
@patch("sys.stdout", new_callable=StringIO)
|
||||
@patch("sys.stderr", new_callable=StringIO)
|
||||
def execute_script(self, script, args, stderr, stdout):
|
||||
module = self.search_script(script)
|
||||
module.main(args)
|
||||
return stdout, stderr
|
||||
|
||||
def test_be_pair_check(self):
|
||||
stdout, stderr = self.execute_script(
|
||||
"be_pair_check", ["-m1", "ODTE", "-m2", "STree"]
|
||||
)
|
||||
self.assertEqual(stderr.getvalue(), "")
|
||||
self.check_output_file(stdout, "paircheck.test")
|
||||
|
||||
def test_be_pair_check_no_data_a(self):
|
||||
stdout, stderr = self.execute_script(
|
||||
"be_pair_check", ["-m1", "SVC", "-m2", "ODTE"]
|
||||
)
|
||||
self.assertEqual(stderr.getvalue(), "")
|
||||
self.assertEqual(stdout.getvalue(), "** No results found **\n")
|
||||
|
||||
def test_be_pair_check_no_data_b(self):
|
||||
stdout, stderr = self.execute_script(
|
||||
"be_pair_check", ["-m1", "STree", "-m2", "SVC"]
|
||||
)
|
||||
self.assertEqual(stderr.getvalue(), "")
|
||||
self.assertEqual(stdout.getvalue(), "** No results found **\n")
|
@@ -1,6 +1,12 @@
|
||||
import os
|
||||
import glob
|
||||
import pathlib
|
||||
import sys
|
||||
import csv
|
||||
import unittest
|
||||
from importlib import import_module
|
||||
from io import StringIO
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
class TestBase(unittest.TestCase):
|
||||
@@ -48,3 +54,24 @@ class TestBase(unittest.TestCase):
|
||||
with open(os.path.join(self.test_files, expected_file)) as f:
|
||||
expected = f.read()
|
||||
self.assertEqual(computed, expected)
|
||||
|
||||
def prepare_scripts_env(self):
|
||||
self.scripts_folder = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "..", "scripts"
|
||||
)
|
||||
sys.path.append(self.scripts_folder)
|
||||
|
||||
def search_script(self, name):
|
||||
py_files = glob.glob(os.path.join(self.scripts_folder, "*.py"))
|
||||
for py_file in py_files:
|
||||
module_name = pathlib.Path(py_file).stem
|
||||
if name == module_name:
|
||||
module = import_module(module_name)
|
||||
return module
|
||||
|
||||
@patch("sys.stdout", new_callable=StringIO)
|
||||
@patch("sys.stderr", new_callable=StringIO)
|
||||
def execute_script(self, script, args, stderr, stdout):
|
||||
module = self.search_script(script)
|
||||
module.main(args)
|
||||
return stdout, stderr
|
||||
|
@@ -11,7 +11,8 @@ from .Benchmark_test import BenchmarkTest
|
||||
from .Summary_test import SummaryTest
|
||||
from .PairCheck_test import PairCheckTest
|
||||
from .Arguments_test import ArgumentsTest
|
||||
from .Scripts_test import ScriptsTest
|
||||
from .scripts.Pair_check_test import BePairCheckTest
|
||||
from .scripts.List_test import ListTest
|
||||
|
||||
all = [
|
||||
"UtilTest",
|
||||
@@ -27,5 +28,6 @@ all = [
|
||||
"SummaryTest",
|
||||
"PairCheckTest",
|
||||
"ArgumentsTest",
|
||||
"ScriptsTest",
|
||||
"BePairCheckTest",
|
||||
"ListTest",
|
||||
]
|
||||
|
21
benchmark/tests/scripts/List_test.py
Normal file
21
benchmark/tests/scripts/List_test.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from ..TestBase import TestBase
|
||||
|
||||
|
||||
class ListTest(TestBase):
|
||||
def setUp(self):
|
||||
self.prepare_scripts_env()
|
||||
|
||||
def test_be_list(self):
|
||||
stdout, stderr = self.execute_script("be_list", ["-m", "STree"])
|
||||
self.assertEqual(stderr.getvalue(), "")
|
||||
self.check_output_file(stdout, "summary_list_model.test")
|
||||
|
||||
def test_be_list_no_data(self):
|
||||
stdout, stderr = self.execute_script(
|
||||
"be_list", ["-m", "Wodt", "-s", "f1-macro"]
|
||||
)
|
||||
self.assertEqual(stderr.getvalue(), "")
|
||||
self.assertEqual(stdout.getvalue(), "** No results found **\n")
|
||||
|
||||
def test_be_list_nan(self):
|
||||
pass
|
27
benchmark/tests/scripts/Pair_check_test.py
Normal file
27
benchmark/tests/scripts/Pair_check_test.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from ..TestBase import TestBase
|
||||
|
||||
|
||||
class BePairCheckTest(TestBase):
|
||||
def setUp(self):
|
||||
self.prepare_scripts_env()
|
||||
|
||||
def test_be_pair_check(self):
|
||||
stdout, stderr = self.execute_script(
|
||||
"be_pair_check", ["-m1", "ODTE", "-m2", "STree"]
|
||||
)
|
||||
self.assertEqual(stderr.getvalue(), "")
|
||||
self.check_output_file(stdout, "paircheck.test")
|
||||
|
||||
def test_be_pair_check_no_data_a(self):
|
||||
stdout, stderr = self.execute_script(
|
||||
"be_pair_check", ["-m1", "SVC", "-m2", "ODTE"]
|
||||
)
|
||||
self.assertEqual(stderr.getvalue(), "")
|
||||
self.assertEqual(stdout.getvalue(), "** No results found **\n")
|
||||
|
||||
def test_be_pair_check_no_data_b(self):
|
||||
stdout, stderr = self.execute_script(
|
||||
"be_pair_check", ["-m1", "STree", "-m2", "SVC"]
|
||||
)
|
||||
self.assertEqual(stderr.getvalue(), "")
|
||||
self.assertEqual(stdout.getvalue(), "** No results found **\n")
|
Reference in New Issue
Block a user