From 3009167813c5edd43cca1b321b309008a3481333 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?=
Date: Fri, 6 May 2022 17:15:24 +0200
Subject: [PATCH] Add some tests

---
 benchmark/Arguments.py                     |  4 +-
 benchmark/scripts/be_list.py               |  4 +-
 benchmark/tests/Arguments_test.py          | 88 +++++++++++++++++++
 benchmark/tests/BestResults_test.py        |  9 ++
 benchmark/tests/Report_test.py             |  8 +-
 benchmark/tests/Summary_test.py            | 14 +++
 benchmark/tests/__init__.py                |  3 +-
 .../tests/test_files/report_datasets.test  |  4 +
 8 files changed, 128 insertions(+), 6 deletions(-)
 create mode 100644 benchmark/tests/Arguments_test.py
 create mode 100644 benchmark/tests/test_files/report_datasets.test

diff --git a/benchmark/Arguments.py b/benchmark/Arguments.py
index 7338ce8..d04dabd 100644
--- a/benchmark/Arguments.py
+++ b/benchmark/Arguments.py
@@ -288,5 +288,5 @@ class Arguments:
         )
         return self
 
-    def parse(self):
-        return self.ap.parse_args()
+    def parse(self, args=None):
+        return self.ap.parse_args(args)
diff --git a/benchmark/scripts/be_list.py b/benchmark/scripts/be_list.py
index 3552d5c..41ded02 100755
--- a/benchmark/scripts/be_list.py
+++ b/benchmark/scripts/be_list.py
@@ -10,8 +10,8 @@ from benchmark.Arguments import Arguments
 
 def main():
     arguments = Arguments()
-    arguments.xset("number").xset("model", required=False).xset("score")
-    arguments.xset("hidden").xset("nan").xset("key")
+    arguments.xset("number").xset("model", required=False).xset("key")
+    arguments.xset("hidden").xset("nan").xset("score", required=False)
     args = arguments.parse()
     data = Summary(hidden=args.hidden)
     data.acquire()
diff --git a/benchmark/tests/Arguments_test.py b/benchmark/tests/Arguments_test.py
new file mode 100644
index 0000000..c941d6c
--- /dev/null
+++ b/benchmark/tests/Arguments_test.py
@@ -0,0 +1,88 @@
+from argparse import ArgumentError
+from io import StringIO
+from unittest.mock import patch
+from .TestBase import TestBase
+from ..Arguments import Arguments, ALL_METRICS
+
+
+class ArgumentsTest(TestBase):
+    def build_args(self):
+        arguments = Arguments()
+        arguments.xset("n_folds").xset("model", mandatory=True)
+        arguments.xset("key", required=True)
+        return arguments
+
+    def test_build_hyperparams_file(self):
+        expected_metrics = (
+            "accuracy",
+            "f1_macro",
+            "f1_micro",
+            "f1_weighted",
+            "roc_auc_ovr",
+        )
+        self.assertSequenceEqual(ALL_METRICS, expected_metrics)
+
+    def test_parameters(self):
+        expected_parameters = {
+            "best": ("-b", "--best"),
+            "color": ("-c", "--color"),
+            "compare": ("-c", "--compare"),
+            "dataset": ("-d", "--dataset"),
+            "excel": ("-x", "--excel"),
+            "file": ("-f", "--file"),
+            "grid": ("-g", "--grid"),
+            "grid_paramfile": ("-g", "--grid_paramfile"),
+            "hidden": ("--hidden",),
+            "hyperparameters": ("-p", "--hyperparameters"),
+            "key": ("-k", "--key"),
+            "lose": ("-l", "--lose"),
+            "model": ("-m", "--model"),
+            "model1": ("-m1", "--model1"),
+            "model2": ("-m2", "--model2"),
+            "nan": ("--nan",),
+            "number": ("-n", "--number"),
+            "n_folds": ("-n", "--n_folds"),
+            "paramfile": ("-f", "--paramfile"),
+            "platform": ("-P", "--platform"),
+            "quiet": ("-q", "--quiet"),
+            "report": ("-r", "--report"),
+            "score": ("-s", "--score"),
+            "sql": ("-q", "--sql"),
+            "stratified": ("-t", "--stratified"),
+            "tex_output": ("-t", "--tex-output"),
+            "title": ("--title",),
+            "win": ("-w", "--win"),
+        }
+        arg = Arguments()
+        for key, value in expected_parameters.items():
+            self.assertSequenceEqual(arg.parameters[key][0], value, key)
+
+    def test_xset(self):
+        arguments = self.build_args()
+        test_args = ["-n", "3", "--model", "SVC", "-k", "metric"]
"3", "--model", "SVC", "-k", "metric"] + args = arguments.parse(test_args) + self.assertEqual(args.n_folds, 3) + self.assertEqual(args.model, "SVC") + self.assertEqual(args.key, "metric") + + @patch("sys.stderr", new_callable=StringIO) + def test_xset_mandatory(self, mock_stderr): + arguments = self.build_args() + test_args = ["-n", "3", "-k", "date"] + with self.assertRaises(SystemExit): + arguments.parse(test_args) + self.assertRegexpMatches( + mock_stderr.getvalue(), + r"error: the following arguments are required: -m/--model", + ) + + @patch("sys.stderr", new_callable=StringIO) + def test_xset_required(self, mock_stderr): + arguments = self.build_args() + test_args = ["-n", "3", "-m", "SVC"] + with self.assertRaises(SystemExit): + arguments.parse(test_args) + self.assertRegexpMatches( + mock_stderr.getvalue(), + r"error: the following arguments are required: -k/--key", + ) diff --git a/benchmark/tests/BestResults_test.py b/benchmark/tests/BestResults_test.py index 761752c..6ee0bb1 100644 --- a/benchmark/tests/BestResults_test.py +++ b/benchmark/tests/BestResults_test.py @@ -62,3 +62,12 @@ class BestResultTest(TestBase): best.fill({}), {"balance-scale": (0.0, {}, ""), "balloons": (0.0, {}, "")}, ) + + def test_build_error(self): + dt = Datasets() + model = "SVC" + best = BestResults( + score="accuracy", model=model, datasets=dt, quiet=True + ) + with self.assertRaises(ValueError): + best.build() diff --git a/benchmark/tests/Report_test.py b/benchmark/tests/Report_test.py index cdc1ef4..1726270 100644 --- a/benchmark/tests/Report_test.py +++ b/benchmark/tests/Report_test.py @@ -1,7 +1,7 @@ from io import StringIO from unittest.mock import patch from .TestBase import TestBase -from ..Results import Report, BaseReport, ReportBest +from ..Results import Report, BaseReport, ReportBest, ReportDatasets from ..Utils import Symbols @@ -79,3 +79,9 @@ class ReportTest(TestBase): with patch(self.output, new=StringIO()) as fake_out: report.report() self.check_output_file(fake_out, "report_best.test") + + @patch("sys.stdout", new_callable=StringIO) + def test_report_datasets(self, mock_output): + report = ReportDatasets() + report.report() + self.check_output_file(mock_output, "report_datasets.test") diff --git a/benchmark/tests/Summary_test.py b/benchmark/tests/Summary_test.py index 644e0d8..e67aab3 100644 --- a/benchmark/tests/Summary_test.py +++ b/benchmark/tests/Summary_test.py @@ -215,3 +215,17 @@ class SummaryTest(TestBase): with patch(self.output, new=StringIO()) as fake_out: report.show_top() self.check_output_file(fake_out, "summary_show_top.test") + + @patch("sys.stdout", new_callable=StringIO) + def test_show_top_no_data(self, fake_out): + report = Summary() + report.acquire() + report.show_top(score="f1-macro") + self.assertEqual(fake_out.getvalue(), "** No results found **\n") + + def test_no_data(self): + report = Summary() + report.acquire() + with self.assertRaises(ValueError) as msg: + report.list_results(score="f1-macro", model="STree") + self.assertEqual(str(msg.exception), "** No results found **") diff --git a/benchmark/tests/__init__.py b/benchmark/tests/__init__.py index 689ba62..e998410 100644 --- a/benchmark/tests/__init__.py +++ b/benchmark/tests/__init__.py @@ -10,6 +10,7 @@ from .SQL_test import SQLTest from .Benchmark_test import BenchmarkTest from .Summary_test import SummaryTest from .PairCheck_test import PairCheckTest +from .Arguments_test import ArgumentsTest all = [ "UtilTest", @@ -24,5 +25,5 @@ all = [ "BenchmarkTest", "SummaryTest", "PairCheckTest", - "be_list", + 
"ArgumentsTest", ] diff --git a/benchmark/tests/test_files/report_datasets.test b/benchmark/tests/test_files/report_datasets.test new file mode 100644 index 0000000..d9581c0 --- /dev/null +++ b/benchmark/tests/test_files/report_datasets.test @@ -0,0 +1,4 @@ +Dataset Samp. Feat. Cls Balance +============================== ===== ===== === ======================================== +balance-scale 625 4 3 7.84%/ 46.08%/ 46.08% +balloons 16 4 2 56.25%/ 43.75%