Add some tests

This commit is contained in:
2022-05-06 17:15:24 +02:00
parent d87c7064a9
commit 3009167813
8 changed files with 128 additions and 6 deletions

View File

@@ -288,5 +288,5 @@ class Arguments:
) )
return self return self
def parse(self): def parse(self, args=None):
return self.ap.parse_args() return self.ap.parse_args(args)

View File

@@ -10,8 +10,8 @@ from benchmark.Arguments import Arguments
def main(): def main():
arguments = Arguments() arguments = Arguments()
arguments.xset("number").xset("model", required=False).xset("score") arguments.xset("number").xset("model", required=False).xset("key")
arguments.xset("hidden").xset("nan").xset("key") arguments.xset("hidden").xset("nan").xset("score", required=False)
args = arguments.parse() args = arguments.parse()
data = Summary(hidden=args.hidden) data = Summary(hidden=args.hidden)
data.acquire() data.acquire()

View File

@@ -0,0 +1,88 @@
from argparse import ArgumentError
from io import StringIO
from unittest.mock import patch
from .TestBase import TestBase
from ..Arguments import Arguments, ALL_METRICS
class ArgumentsTest(TestBase):
def build_args(self):
arguments = Arguments()
arguments.xset("n_folds").xset("model", mandatory=True)
arguments.xset("key", required=True)
return arguments
def test_build_hyperparams_file(self):
expected_metrics = (
"accuracy",
"f1_macro",
"f1_micro",
"f1_weighted",
"roc_auc_ovr",
)
self.assertSequenceEqual(ALL_METRICS, expected_metrics)
def test_parameters(self):
expected_parameters = {
"best": ("-b", "--best"),
"color": ("-c", "--color"),
"compare": ("-c", "--compare"),
"dataset": ("-d", "--dataset"),
"excel": ("-x", "--excel"),
"file": ("-f", "--file"),
"grid": ("-g", "--grid"),
"grid_paramfile": ("-g", "--grid_paramfile"),
"hidden": ("--hidden",),
"hyperparameters": ("-p", "--hyperparameters"),
"key": ("-k", "--key"),
"lose": ("-l", "--lose"),
"model": ("-m", "--model"),
"model1": ("-m1", "--model1"),
"model2": ("-m2", "--model2"),
"nan": ("--nan",),
"number": ("-n", "--number"),
"n_folds": ("-n", "--n_folds"),
"paramfile": ("-f", "--paramfile"),
"platform": ("-P", "--platform"),
"quiet": ("-q", "--quiet"),
"report": ("-r", "--report"),
"score": ("-s", "--score"),
"sql": ("-q", "--sql"),
"stratified": ("-t", "--stratified"),
"tex_output": ("-t", "--tex-output"),
"title": ("--title",),
"win": ("-w", "--win"),
}
arg = Arguments()
for key, value in expected_parameters.items():
self.assertSequenceEqual(arg.parameters[key][0], value, key)
def test_xset(self):
arguments = self.build_args()
test_args = ["-n", "3", "--model", "SVC", "-k", "metric"]
args = arguments.parse(test_args)
self.assertEqual(args.n_folds, 3)
self.assertEqual(args.model, "SVC")
self.assertEqual(args.key, "metric")
@patch("sys.stderr", new_callable=StringIO)
def test_xset_mandatory(self, mock_stderr):
arguments = self.build_args()
test_args = ["-n", "3", "-k", "date"]
with self.assertRaises(SystemExit):
arguments.parse(test_args)
self.assertRegexpMatches(
mock_stderr.getvalue(),
r"error: the following arguments are required: -m/--model",
)
@patch("sys.stderr", new_callable=StringIO)
def test_xset_required(self, mock_stderr):
arguments = self.build_args()
test_args = ["-n", "3", "-m", "SVC"]
with self.assertRaises(SystemExit):
arguments.parse(test_args)
self.assertRegexpMatches(
mock_stderr.getvalue(),
r"error: the following arguments are required: -k/--key",
)

View File

@@ -62,3 +62,12 @@ class BestResultTest(TestBase):
best.fill({}), best.fill({}),
{"balance-scale": (0.0, {}, ""), "balloons": (0.0, {}, "")}, {"balance-scale": (0.0, {}, ""), "balloons": (0.0, {}, "")},
) )
def test_build_error(self):
dt = Datasets()
model = "SVC"
best = BestResults(
score="accuracy", model=model, datasets=dt, quiet=True
)
with self.assertRaises(ValueError):
best.build()

View File

@@ -1,7 +1,7 @@
from io import StringIO from io import StringIO
from unittest.mock import patch from unittest.mock import patch
from .TestBase import TestBase from .TestBase import TestBase
from ..Results import Report, BaseReport, ReportBest from ..Results import Report, BaseReport, ReportBest, ReportDatasets
from ..Utils import Symbols from ..Utils import Symbols
@@ -79,3 +79,9 @@ class ReportTest(TestBase):
with patch(self.output, new=StringIO()) as fake_out: with patch(self.output, new=StringIO()) as fake_out:
report.report() report.report()
self.check_output_file(fake_out, "report_best.test") self.check_output_file(fake_out, "report_best.test")
@patch("sys.stdout", new_callable=StringIO)
def test_report_datasets(self, mock_output):
report = ReportDatasets()
report.report()
self.check_output_file(mock_output, "report_datasets.test")

View File

@@ -215,3 +215,17 @@ class SummaryTest(TestBase):
with patch(self.output, new=StringIO()) as fake_out: with patch(self.output, new=StringIO()) as fake_out:
report.show_top() report.show_top()
self.check_output_file(fake_out, "summary_show_top.test") self.check_output_file(fake_out, "summary_show_top.test")
@patch("sys.stdout", new_callable=StringIO)
def test_show_top_no_data(self, fake_out):
report = Summary()
report.acquire()
report.show_top(score="f1-macro")
self.assertEqual(fake_out.getvalue(), "** No results found **\n")
def test_no_data(self):
report = Summary()
report.acquire()
with self.assertRaises(ValueError) as msg:
report.list_results(score="f1-macro", model="STree")
self.assertEqual(str(msg.exception), "** No results found **")

View File

@@ -10,6 +10,7 @@ from .SQL_test import SQLTest
from .Benchmark_test import BenchmarkTest from .Benchmark_test import BenchmarkTest
from .Summary_test import SummaryTest from .Summary_test import SummaryTest
from .PairCheck_test import PairCheckTest from .PairCheck_test import PairCheckTest
from .Arguments_test import ArgumentsTest
all = [ all = [
"UtilTest", "UtilTest",
@@ -24,5 +25,5 @@ all = [
"BenchmarkTest", "BenchmarkTest",
"SummaryTest", "SummaryTest",
"PairCheckTest", "PairCheckTest",
"be_list", "ArgumentsTest",
] ]

View File

@@ -0,0 +1,4 @@
Dataset Samp. Feat. Cls Balance
============================== ===== ===== === ========================================
balance-scale 625 4 3 7.84%/ 46.08%/ 46.08%
balloons 16 4 2 56.25%/ 43.75%