Add some tests

This commit is contained in:
2022-05-06 17:15:24 +02:00
parent d87c7064a9
commit 3009167813
8 changed files with 128 additions and 6 deletions

View File

@@ -288,5 +288,5 @@ class Arguments:
)
return self
def parse(self):
return self.ap.parse_args()
def parse(self, args=None):
return self.ap.parse_args(args)

View File

@@ -10,8 +10,8 @@ from benchmark.Arguments import Arguments
def main():
arguments = Arguments()
arguments.xset("number").xset("model", required=False).xset("score")
arguments.xset("hidden").xset("nan").xset("key")
arguments.xset("number").xset("model", required=False).xset("key")
arguments.xset("hidden").xset("nan").xset("score", required=False)
args = arguments.parse()
data = Summary(hidden=args.hidden)
data.acquire()

View File

@@ -0,0 +1,88 @@
from argparse import ArgumentError
from io import StringIO
from unittest.mock import patch
from .TestBase import TestBase
from ..Arguments import Arguments, ALL_METRICS
class ArgumentsTest(TestBase):
def build_args(self):
arguments = Arguments()
arguments.xset("n_folds").xset("model", mandatory=True)
arguments.xset("key", required=True)
return arguments
def test_build_hyperparams_file(self):
expected_metrics = (
"accuracy",
"f1_macro",
"f1_micro",
"f1_weighted",
"roc_auc_ovr",
)
self.assertSequenceEqual(ALL_METRICS, expected_metrics)
def test_parameters(self):
expected_parameters = {
"best": ("-b", "--best"),
"color": ("-c", "--color"),
"compare": ("-c", "--compare"),
"dataset": ("-d", "--dataset"),
"excel": ("-x", "--excel"),
"file": ("-f", "--file"),
"grid": ("-g", "--grid"),
"grid_paramfile": ("-g", "--grid_paramfile"),
"hidden": ("--hidden",),
"hyperparameters": ("-p", "--hyperparameters"),
"key": ("-k", "--key"),
"lose": ("-l", "--lose"),
"model": ("-m", "--model"),
"model1": ("-m1", "--model1"),
"model2": ("-m2", "--model2"),
"nan": ("--nan",),
"number": ("-n", "--number"),
"n_folds": ("-n", "--n_folds"),
"paramfile": ("-f", "--paramfile"),
"platform": ("-P", "--platform"),
"quiet": ("-q", "--quiet"),
"report": ("-r", "--report"),
"score": ("-s", "--score"),
"sql": ("-q", "--sql"),
"stratified": ("-t", "--stratified"),
"tex_output": ("-t", "--tex-output"),
"title": ("--title",),
"win": ("-w", "--win"),
}
arg = Arguments()
for key, value in expected_parameters.items():
self.assertSequenceEqual(arg.parameters[key][0], value, key)
def test_xset(self):
arguments = self.build_args()
test_args = ["-n", "3", "--model", "SVC", "-k", "metric"]
args = arguments.parse(test_args)
self.assertEqual(args.n_folds, 3)
self.assertEqual(args.model, "SVC")
self.assertEqual(args.key, "metric")
@patch("sys.stderr", new_callable=StringIO)
def test_xset_mandatory(self, mock_stderr):
arguments = self.build_args()
test_args = ["-n", "3", "-k", "date"]
with self.assertRaises(SystemExit):
arguments.parse(test_args)
self.assertRegexpMatches(
mock_stderr.getvalue(),
r"error: the following arguments are required: -m/--model",
)
@patch("sys.stderr", new_callable=StringIO)
def test_xset_required(self, mock_stderr):
arguments = self.build_args()
test_args = ["-n", "3", "-m", "SVC"]
with self.assertRaises(SystemExit):
arguments.parse(test_args)
self.assertRegexpMatches(
mock_stderr.getvalue(),
r"error: the following arguments are required: -k/--key",
)

View File

@@ -62,3 +62,12 @@ class BestResultTest(TestBase):
best.fill({}),
{"balance-scale": (0.0, {}, ""), "balloons": (0.0, {}, "")},
)
def test_build_error(self):
dt = Datasets()
model = "SVC"
best = BestResults(
score="accuracy", model=model, datasets=dt, quiet=True
)
with self.assertRaises(ValueError):
best.build()

View File

@@ -1,7 +1,7 @@
from io import StringIO
from unittest.mock import patch
from .TestBase import TestBase
from ..Results import Report, BaseReport, ReportBest
from ..Results import Report, BaseReport, ReportBest, ReportDatasets
from ..Utils import Symbols
@@ -79,3 +79,9 @@ class ReportTest(TestBase):
with patch(self.output, new=StringIO()) as fake_out:
report.report()
self.check_output_file(fake_out, "report_best.test")
@patch("sys.stdout", new_callable=StringIO)
def test_report_datasets(self, mock_output):
report = ReportDatasets()
report.report()
self.check_output_file(mock_output, "report_datasets.test")

View File

@@ -215,3 +215,17 @@ class SummaryTest(TestBase):
with patch(self.output, new=StringIO()) as fake_out:
report.show_top()
self.check_output_file(fake_out, "summary_show_top.test")
@patch("sys.stdout", new_callable=StringIO)
def test_show_top_no_data(self, fake_out):
report = Summary()
report.acquire()
report.show_top(score="f1-macro")
self.assertEqual(fake_out.getvalue(), "** No results found **\n")
def test_no_data(self):
report = Summary()
report.acquire()
with self.assertRaises(ValueError) as msg:
report.list_results(score="f1-macro", model="STree")
self.assertEqual(str(msg.exception), "** No results found **")

View File

@@ -10,6 +10,7 @@ from .SQL_test import SQLTest
from .Benchmark_test import BenchmarkTest
from .Summary_test import SummaryTest
from .PairCheck_test import PairCheckTest
from .Arguments_test import ArgumentsTest
all = [
"UtilTest",
@@ -24,5 +25,5 @@ all = [
"BenchmarkTest",
"SummaryTest",
"PairCheckTest",
"be_list",
"ArgumentsTest",
]

View File

@@ -0,0 +1,4 @@
Dataset Samp. Feat. Cls Balance
============================== ===== ===== === ========================================
balance-scale 625 4 3 7.84%/ 46.08%/ 46.08%
balloons 16 4 2 56.25%/ 43.75%