mirror of
https://github.com/Doctorado-ML/benchmark.git
synced 2025-08-15 23:45:54 +00:00
Add some tests
This commit is contained in:
@@ -288,5 +288,5 @@ class Arguments:
|
||||
)
|
||||
return self
|
||||
|
||||
def parse(self):
|
||||
return self.ap.parse_args()
|
||||
def parse(self, args=None):
|
||||
return self.ap.parse_args(args)
|
||||
|
@@ -10,8 +10,8 @@ from benchmark.Arguments import Arguments
|
||||
|
||||
def main():
|
||||
arguments = Arguments()
|
||||
arguments.xset("number").xset("model", required=False).xset("score")
|
||||
arguments.xset("hidden").xset("nan").xset("key")
|
||||
arguments.xset("number").xset("model", required=False).xset("key")
|
||||
arguments.xset("hidden").xset("nan").xset("score", required=False)
|
||||
args = arguments.parse()
|
||||
data = Summary(hidden=args.hidden)
|
||||
data.acquire()
|
||||
|
88
benchmark/tests/Arguments_test.py
Normal file
88
benchmark/tests/Arguments_test.py
Normal file
@@ -0,0 +1,88 @@
|
||||
from argparse import ArgumentError
|
||||
from io import StringIO
|
||||
from unittest.mock import patch
|
||||
from .TestBase import TestBase
|
||||
from ..Arguments import Arguments, ALL_METRICS
|
||||
|
||||
|
||||
class ArgumentsTest(TestBase):
|
||||
def build_args(self):
|
||||
arguments = Arguments()
|
||||
arguments.xset("n_folds").xset("model", mandatory=True)
|
||||
arguments.xset("key", required=True)
|
||||
return arguments
|
||||
|
||||
def test_build_hyperparams_file(self):
|
||||
expected_metrics = (
|
||||
"accuracy",
|
||||
"f1_macro",
|
||||
"f1_micro",
|
||||
"f1_weighted",
|
||||
"roc_auc_ovr",
|
||||
)
|
||||
self.assertSequenceEqual(ALL_METRICS, expected_metrics)
|
||||
|
||||
def test_parameters(self):
|
||||
expected_parameters = {
|
||||
"best": ("-b", "--best"),
|
||||
"color": ("-c", "--color"),
|
||||
"compare": ("-c", "--compare"),
|
||||
"dataset": ("-d", "--dataset"),
|
||||
"excel": ("-x", "--excel"),
|
||||
"file": ("-f", "--file"),
|
||||
"grid": ("-g", "--grid"),
|
||||
"grid_paramfile": ("-g", "--grid_paramfile"),
|
||||
"hidden": ("--hidden",),
|
||||
"hyperparameters": ("-p", "--hyperparameters"),
|
||||
"key": ("-k", "--key"),
|
||||
"lose": ("-l", "--lose"),
|
||||
"model": ("-m", "--model"),
|
||||
"model1": ("-m1", "--model1"),
|
||||
"model2": ("-m2", "--model2"),
|
||||
"nan": ("--nan",),
|
||||
"number": ("-n", "--number"),
|
||||
"n_folds": ("-n", "--n_folds"),
|
||||
"paramfile": ("-f", "--paramfile"),
|
||||
"platform": ("-P", "--platform"),
|
||||
"quiet": ("-q", "--quiet"),
|
||||
"report": ("-r", "--report"),
|
||||
"score": ("-s", "--score"),
|
||||
"sql": ("-q", "--sql"),
|
||||
"stratified": ("-t", "--stratified"),
|
||||
"tex_output": ("-t", "--tex-output"),
|
||||
"title": ("--title",),
|
||||
"win": ("-w", "--win"),
|
||||
}
|
||||
arg = Arguments()
|
||||
for key, value in expected_parameters.items():
|
||||
self.assertSequenceEqual(arg.parameters[key][0], value, key)
|
||||
|
||||
def test_xset(self):
|
||||
arguments = self.build_args()
|
||||
test_args = ["-n", "3", "--model", "SVC", "-k", "metric"]
|
||||
args = arguments.parse(test_args)
|
||||
self.assertEqual(args.n_folds, 3)
|
||||
self.assertEqual(args.model, "SVC")
|
||||
self.assertEqual(args.key, "metric")
|
||||
|
||||
@patch("sys.stderr", new_callable=StringIO)
|
||||
def test_xset_mandatory(self, mock_stderr):
|
||||
arguments = self.build_args()
|
||||
test_args = ["-n", "3", "-k", "date"]
|
||||
with self.assertRaises(SystemExit):
|
||||
arguments.parse(test_args)
|
||||
self.assertRegexpMatches(
|
||||
mock_stderr.getvalue(),
|
||||
r"error: the following arguments are required: -m/--model",
|
||||
)
|
||||
|
||||
@patch("sys.stderr", new_callable=StringIO)
|
||||
def test_xset_required(self, mock_stderr):
|
||||
arguments = self.build_args()
|
||||
test_args = ["-n", "3", "-m", "SVC"]
|
||||
with self.assertRaises(SystemExit):
|
||||
arguments.parse(test_args)
|
||||
self.assertRegexpMatches(
|
||||
mock_stderr.getvalue(),
|
||||
r"error: the following arguments are required: -k/--key",
|
||||
)
|
@@ -62,3 +62,12 @@ class BestResultTest(TestBase):
|
||||
best.fill({}),
|
||||
{"balance-scale": (0.0, {}, ""), "balloons": (0.0, {}, "")},
|
||||
)
|
||||
|
||||
def test_build_error(self):
|
||||
dt = Datasets()
|
||||
model = "SVC"
|
||||
best = BestResults(
|
||||
score="accuracy", model=model, datasets=dt, quiet=True
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
best.build()
|
||||
|
@@ -1,7 +1,7 @@
|
||||
from io import StringIO
|
||||
from unittest.mock import patch
|
||||
from .TestBase import TestBase
|
||||
from ..Results import Report, BaseReport, ReportBest
|
||||
from ..Results import Report, BaseReport, ReportBest, ReportDatasets
|
||||
from ..Utils import Symbols
|
||||
|
||||
|
||||
@@ -79,3 +79,9 @@ class ReportTest(TestBase):
|
||||
with patch(self.output, new=StringIO()) as fake_out:
|
||||
report.report()
|
||||
self.check_output_file(fake_out, "report_best.test")
|
||||
|
||||
@patch("sys.stdout", new_callable=StringIO)
|
||||
def test_report_datasets(self, mock_output):
|
||||
report = ReportDatasets()
|
||||
report.report()
|
||||
self.check_output_file(mock_output, "report_datasets.test")
|
||||
|
@@ -215,3 +215,17 @@ class SummaryTest(TestBase):
|
||||
with patch(self.output, new=StringIO()) as fake_out:
|
||||
report.show_top()
|
||||
self.check_output_file(fake_out, "summary_show_top.test")
|
||||
|
||||
@patch("sys.stdout", new_callable=StringIO)
|
||||
def test_show_top_no_data(self, fake_out):
|
||||
report = Summary()
|
||||
report.acquire()
|
||||
report.show_top(score="f1-macro")
|
||||
self.assertEqual(fake_out.getvalue(), "** No results found **\n")
|
||||
|
||||
def test_no_data(self):
|
||||
report = Summary()
|
||||
report.acquire()
|
||||
with self.assertRaises(ValueError) as msg:
|
||||
report.list_results(score="f1-macro", model="STree")
|
||||
self.assertEqual(str(msg.exception), "** No results found **")
|
||||
|
@@ -10,6 +10,7 @@ from .SQL_test import SQLTest
|
||||
from .Benchmark_test import BenchmarkTest
|
||||
from .Summary_test import SummaryTest
|
||||
from .PairCheck_test import PairCheckTest
|
||||
from .Arguments_test import ArgumentsTest
|
||||
|
||||
all = [
|
||||
"UtilTest",
|
||||
@@ -24,5 +25,5 @@ all = [
|
||||
"BenchmarkTest",
|
||||
"SummaryTest",
|
||||
"PairCheckTest",
|
||||
"be_list",
|
||||
"ArgumentsTest",
|
||||
]
|
||||
|
4
benchmark/tests/test_files/report_datasets.test
Normal file
4
benchmark/tests/test_files/report_datasets.test
Normal file
@@ -0,0 +1,4 @@
|
||||
[94mDataset Samp. Feat. Cls Balance
|
||||
============================== ===== ===== === ========================================
|
||||
[96mbalance-scale 625 4 3 7.84%/ 46.08%/ 46.08%
|
||||
[94mballoons 16 4 2 56.25%/ 43.75%
|
Reference in New Issue
Block a user