mirror of
https://github.com/Doctorado-ML/benchmark.git
synced 2025-08-18 00:45:54 +00:00
Refactor scripts testing
This commit is contained in:
@@ -4,10 +4,10 @@ from .Utils import Files
|
|||||||
|
|
||||||
ALL_METRICS = (
|
ALL_METRICS = (
|
||||||
"accuracy",
|
"accuracy",
|
||||||
"f1_macro",
|
"f1-macro",
|
||||||
"f1_micro",
|
"f1-micro",
|
||||||
"f1_weighted",
|
"f1-weighted",
|
||||||
"roc_auc_ovr",
|
"roc-auc-ovr",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@@ -70,7 +70,8 @@ class DatasetsSurcov:
|
|||||||
)
|
)
|
||||||
data.dropna(axis=0, how="any", inplace=True)
|
data.dropna(axis=0, how="any", inplace=True)
|
||||||
self.columns = data.columns
|
self.columns = data.columns
|
||||||
X = data.drop("class", axis=1).to_numpy()
|
col_list = ["class"]
|
||||||
|
X = data.drop(col_list, axis=1).to_numpy()
|
||||||
y = data["class"].to_numpy()
|
y = data["class"].to_numpy()
|
||||||
return X, y
|
return X, y
|
||||||
|
|
||||||
|
@@ -8,11 +8,11 @@ from benchmark.Arguments import Arguments
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main(args_test=None):
|
||||||
arguments = Arguments()
|
arguments = Arguments()
|
||||||
arguments.xset("number").xset("model", required=False).xset("key")
|
arguments.xset("number").xset("model", required=False).xset("key")
|
||||||
arguments.xset("hidden").xset("nan").xset("score", required=False)
|
arguments.xset("hidden").xset("nan").xset("score", required=False)
|
||||||
args = arguments.parse()
|
args = arguments.parse(args_test)
|
||||||
data = Summary(hidden=args.hidden)
|
data = Summary(hidden=args.hidden)
|
||||||
data.acquire()
|
data.acquire()
|
||||||
try:
|
try:
|
||||||
|
@@ -6,11 +6,11 @@ from benchmark.Arguments import Arguments
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def main(argx=None):
|
def main(args_test=None):
|
||||||
arguments = Arguments()
|
arguments = Arguments()
|
||||||
arguments.xset("score").xset("win").xset("model1").xset("model2")
|
arguments.xset("score").xset("win").xset("model1").xset("model2")
|
||||||
arguments.xset("lose")
|
arguments.xset("lose")
|
||||||
args = arguments.parse(argx)
|
args = arguments.parse(args_test)
|
||||||
pair_check = PairCheck(
|
pair_check = PairCheck(
|
||||||
args.score,
|
args.score,
|
||||||
args.model1,
|
args.model1,
|
||||||
|
@@ -14,10 +14,10 @@ class ArgumentsTest(TestBase):
|
|||||||
def test_build_hyperparams_file(self):
|
def test_build_hyperparams_file(self):
|
||||||
expected_metrics = (
|
expected_metrics = (
|
||||||
"accuracy",
|
"accuracy",
|
||||||
"f1_macro",
|
"f1-macro",
|
||||||
"f1_micro",
|
"f1-micro",
|
||||||
"f1_weighted",
|
"f1-weighted",
|
||||||
"roc_auc_ovr",
|
"roc-auc-ovr",
|
||||||
)
|
)
|
||||||
self.assertSequenceEqual(ALL_METRICS, expected_metrics)
|
self.assertSequenceEqual(ALL_METRICS, expected_metrics)
|
||||||
|
|
||||||
|
@@ -1,52 +0,0 @@
|
|||||||
import os
|
|
||||||
import sys
|
|
||||||
import glob
|
|
||||||
import pathlib
|
|
||||||
from importlib import import_module
|
|
||||||
from io import StringIO
|
|
||||||
from unittest.mock import patch
|
|
||||||
from .TestBase import TestBase
|
|
||||||
|
|
||||||
|
|
||||||
class ScriptsTest(TestBase):
|
|
||||||
def setUp(self):
|
|
||||||
self.scripts_folder = os.path.join(
|
|
||||||
os.path.dirname(os.path.abspath(__file__)), "..", "scripts"
|
|
||||||
)
|
|
||||||
sys.path.append(self.scripts_folder)
|
|
||||||
|
|
||||||
def search_script(self, name):
|
|
||||||
py_files = glob.glob(os.path.join(self.scripts_folder, "*.py"))
|
|
||||||
for py_file in py_files:
|
|
||||||
module_name = pathlib.Path(py_file).stem
|
|
||||||
if name == module_name:
|
|
||||||
module = import_module(module_name)
|
|
||||||
return module
|
|
||||||
|
|
||||||
@patch("sys.stdout", new_callable=StringIO)
|
|
||||||
@patch("sys.stderr", new_callable=StringIO)
|
|
||||||
def execute_script(self, script, args, stderr, stdout):
|
|
||||||
module = self.search_script(script)
|
|
||||||
module.main(args)
|
|
||||||
return stdout, stderr
|
|
||||||
|
|
||||||
def test_be_pair_check(self):
|
|
||||||
stdout, stderr = self.execute_script(
|
|
||||||
"be_pair_check", ["-m1", "ODTE", "-m2", "STree"]
|
|
||||||
)
|
|
||||||
self.assertEqual(stderr.getvalue(), "")
|
|
||||||
self.check_output_file(stdout, "paircheck.test")
|
|
||||||
|
|
||||||
def test_be_pair_check_no_data_a(self):
|
|
||||||
stdout, stderr = self.execute_script(
|
|
||||||
"be_pair_check", ["-m1", "SVC", "-m2", "ODTE"]
|
|
||||||
)
|
|
||||||
self.assertEqual(stderr.getvalue(), "")
|
|
||||||
self.assertEqual(stdout.getvalue(), "** No results found **\n")
|
|
||||||
|
|
||||||
def test_be_pair_check_no_data_b(self):
|
|
||||||
stdout, stderr = self.execute_script(
|
|
||||||
"be_pair_check", ["-m1", "STree", "-m2", "SVC"]
|
|
||||||
)
|
|
||||||
self.assertEqual(stderr.getvalue(), "")
|
|
||||||
self.assertEqual(stdout.getvalue(), "** No results found **\n")
|
|
@@ -1,6 +1,12 @@
|
|||||||
import os
|
import os
|
||||||
|
import glob
|
||||||
|
import pathlib
|
||||||
|
import sys
|
||||||
import csv
|
import csv
|
||||||
import unittest
|
import unittest
|
||||||
|
from importlib import import_module
|
||||||
|
from io import StringIO
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
|
||||||
class TestBase(unittest.TestCase):
|
class TestBase(unittest.TestCase):
|
||||||
@@ -48,3 +54,24 @@ class TestBase(unittest.TestCase):
|
|||||||
with open(os.path.join(self.test_files, expected_file)) as f:
|
with open(os.path.join(self.test_files, expected_file)) as f:
|
||||||
expected = f.read()
|
expected = f.read()
|
||||||
self.assertEqual(computed, expected)
|
self.assertEqual(computed, expected)
|
||||||
|
|
||||||
|
def prepare_scripts_env(self):
|
||||||
|
self.scripts_folder = os.path.join(
|
||||||
|
os.path.dirname(os.path.abspath(__file__)), "..", "scripts"
|
||||||
|
)
|
||||||
|
sys.path.append(self.scripts_folder)
|
||||||
|
|
||||||
|
def search_script(self, name):
|
||||||
|
py_files = glob.glob(os.path.join(self.scripts_folder, "*.py"))
|
||||||
|
for py_file in py_files:
|
||||||
|
module_name = pathlib.Path(py_file).stem
|
||||||
|
if name == module_name:
|
||||||
|
module = import_module(module_name)
|
||||||
|
return module
|
||||||
|
|
||||||
|
@patch("sys.stdout", new_callable=StringIO)
|
||||||
|
@patch("sys.stderr", new_callable=StringIO)
|
||||||
|
def execute_script(self, script, args, stderr, stdout):
|
||||||
|
module = self.search_script(script)
|
||||||
|
module.main(args)
|
||||||
|
return stdout, stderr
|
||||||
|
@@ -11,7 +11,8 @@ from .Benchmark_test import BenchmarkTest
|
|||||||
from .Summary_test import SummaryTest
|
from .Summary_test import SummaryTest
|
||||||
from .PairCheck_test import PairCheckTest
|
from .PairCheck_test import PairCheckTest
|
||||||
from .Arguments_test import ArgumentsTest
|
from .Arguments_test import ArgumentsTest
|
||||||
from .Scripts_test import ScriptsTest
|
from .scripts.Pair_check_test import BePairCheckTest
|
||||||
|
from .scripts.List_test import ListTest
|
||||||
|
|
||||||
all = [
|
all = [
|
||||||
"UtilTest",
|
"UtilTest",
|
||||||
@@ -27,5 +28,6 @@ all = [
|
|||||||
"SummaryTest",
|
"SummaryTest",
|
||||||
"PairCheckTest",
|
"PairCheckTest",
|
||||||
"ArgumentsTest",
|
"ArgumentsTest",
|
||||||
"ScriptsTest",
|
"BePairCheckTest",
|
||||||
|
"ListTest",
|
||||||
]
|
]
|
||||||
|
21
benchmark/tests/scripts/List_test.py
Normal file
21
benchmark/tests/scripts/List_test.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
from ..TestBase import TestBase
|
||||||
|
|
||||||
|
|
||||||
|
class ListTest(TestBase):
|
||||||
|
def setUp(self):
|
||||||
|
self.prepare_scripts_env()
|
||||||
|
|
||||||
|
def test_be_list(self):
|
||||||
|
stdout, stderr = self.execute_script("be_list", ["-m", "STree"])
|
||||||
|
self.assertEqual(stderr.getvalue(), "")
|
||||||
|
self.check_output_file(stdout, "summary_list_model.test")
|
||||||
|
|
||||||
|
def test_be_list_no_data(self):
|
||||||
|
stdout, stderr = self.execute_script(
|
||||||
|
"be_list", ["-m", "Wodt", "-s", "f1-macro"]
|
||||||
|
)
|
||||||
|
self.assertEqual(stderr.getvalue(), "")
|
||||||
|
self.assertEqual(stdout.getvalue(), "** No results found **\n")
|
||||||
|
|
||||||
|
def test_be_list_nan(self):
|
||||||
|
pass
|
27
benchmark/tests/scripts/Pair_check_test.py
Normal file
27
benchmark/tests/scripts/Pair_check_test.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
from ..TestBase import TestBase
|
||||||
|
|
||||||
|
|
||||||
|
class BePairCheckTest(TestBase):
|
||||||
|
def setUp(self):
|
||||||
|
self.prepare_scripts_env()
|
||||||
|
|
||||||
|
def test_be_pair_check(self):
|
||||||
|
stdout, stderr = self.execute_script(
|
||||||
|
"be_pair_check", ["-m1", "ODTE", "-m2", "STree"]
|
||||||
|
)
|
||||||
|
self.assertEqual(stderr.getvalue(), "")
|
||||||
|
self.check_output_file(stdout, "paircheck.test")
|
||||||
|
|
||||||
|
def test_be_pair_check_no_data_a(self):
|
||||||
|
stdout, stderr = self.execute_script(
|
||||||
|
"be_pair_check", ["-m1", "SVC", "-m2", "ODTE"]
|
||||||
|
)
|
||||||
|
self.assertEqual(stderr.getvalue(), "")
|
||||||
|
self.assertEqual(stdout.getvalue(), "** No results found **\n")
|
||||||
|
|
||||||
|
def test_be_pair_check_no_data_b(self):
|
||||||
|
stdout, stderr = self.execute_script(
|
||||||
|
"be_pair_check", ["-m1", "STree", "-m2", "SVC"]
|
||||||
|
)
|
||||||
|
self.assertEqual(stderr.getvalue(), "")
|
||||||
|
self.assertEqual(stdout.getvalue(), "** No results found **\n")
|
Reference in New Issue
Block a user