Mirror of https://github.com/Doctorado-ML/benchmark.git (synced 2025-08-17 00:15:55 +00:00)

Commit: Add some tests
@@ -288,5 +288,5 @@ class Arguments:
         )
         return self

-    def parse(self):
-        return self.ap.parse_args()
+    def parse(self, args=None):
+        return self.ap.parse_args(args)
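Note: `ArgumentParser.parse_args(None)` falls back to `sys.argv[1:]`, so the new optional `args` parameter leaves command-line behaviour unchanged while letting tests inject an argv list directly. A minimal standalone sketch of the idea (not the project's `Arguments` class; the `-n/--n_folds` option is borrowed from the test file below):

    from argparse import ArgumentParser

    ap = ArgumentParser()
    ap.add_argument("-n", "--n_folds", type=int)

    # tests pass an explicit argv list ...
    args = ap.parse_args(["-n", "3"])
    assert args.n_folds == 3
    # ... while real CLI runs call parse_args() and read sys.argv[1:]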
@@ -10,8 +10,8 @@ from benchmark.Arguments import Arguments

 def main():
     arguments = Arguments()
-    arguments.xset("number").xset("model", required=False).xset("score")
-    arguments.xset("hidden").xset("nan").xset("key")
+    arguments.xset("number").xset("model", required=False).xset("key")
+    arguments.xset("hidden").xset("nan").xset("score", required=False)
     args = arguments.parse()
     data = Summary(hidden=args.hidden)
     data.acquire()
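The chained `xset(...)` calls work because `xset` ends with `return self`, as the context of the first hunk shows. A stripped-down sketch of that fluent pattern (illustrative only; the real `Arguments` class resolves each name through a table of predefined flags, types and defaults, as `test_parameters` below verifies):

    from argparse import ArgumentParser

    class FluentArgs:
        def __init__(self):
            self.ap = ArgumentParser()

        def xset(self, name, **kwargs):
            # hypothetical: derive the flag from the name; the real class
            # looks it up in its parameters table instead
            self.ap.add_argument(f"--{name}", **kwargs)
            return self  # returning self is what enables chaining

        def parse(self, args=None):
            return self.ap.parse_args(args)

    args = FluentArgs().xset("number").xset("key").parse(
        ["--number", "1", "--key", "x"]
    )
    assert args.number == "1" and args.key == "x"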
benchmark/tests/Arguments_test.py (new file, 88 lines)
@@ -0,0 +1,88 @@
+from argparse import ArgumentError
+from io import StringIO
+from unittest.mock import patch
+from .TestBase import TestBase
+from ..Arguments import Arguments, ALL_METRICS
+
+
+class ArgumentsTest(TestBase):
+    def build_args(self):
+        arguments = Arguments()
+        arguments.xset("n_folds").xset("model", mandatory=True)
+        arguments.xset("key", required=True)
+        return arguments
+
+    def test_build_hyperparams_file(self):
+        expected_metrics = (
+            "accuracy",
+            "f1_macro",
+            "f1_micro",
+            "f1_weighted",
+            "roc_auc_ovr",
+        )
+        self.assertSequenceEqual(ALL_METRICS, expected_metrics)
+
+    def test_parameters(self):
+        expected_parameters = {
+            "best": ("-b", "--best"),
+            "color": ("-c", "--color"),
+            "compare": ("-c", "--compare"),
+            "dataset": ("-d", "--dataset"),
+            "excel": ("-x", "--excel"),
+            "file": ("-f", "--file"),
+            "grid": ("-g", "--grid"),
+            "grid_paramfile": ("-g", "--grid_paramfile"),
+            "hidden": ("--hidden",),
+            "hyperparameters": ("-p", "--hyperparameters"),
+            "key": ("-k", "--key"),
+            "lose": ("-l", "--lose"),
+            "model": ("-m", "--model"),
+            "model1": ("-m1", "--model1"),
+            "model2": ("-m2", "--model2"),
+            "nan": ("--nan",),
+            "number": ("-n", "--number"),
+            "n_folds": ("-n", "--n_folds"),
+            "paramfile": ("-f", "--paramfile"),
+            "platform": ("-P", "--platform"),
+            "quiet": ("-q", "--quiet"),
+            "report": ("-r", "--report"),
+            "score": ("-s", "--score"),
+            "sql": ("-q", "--sql"),
+            "stratified": ("-t", "--stratified"),
+            "tex_output": ("-t", "--tex-output"),
+            "title": ("--title",),
+            "win": ("-w", "--win"),
+        }
+        arg = Arguments()
+        for key, value in expected_parameters.items():
+            self.assertSequenceEqual(arg.parameters[key][0], value, key)
+
+    def test_xset(self):
+        arguments = self.build_args()
+        test_args = ["-n", "3", "--model", "SVC", "-k", "metric"]
+        args = arguments.parse(test_args)
+        self.assertEqual(args.n_folds, 3)
+        self.assertEqual(args.model, "SVC")
+        self.assertEqual(args.key, "metric")
+
+    @patch("sys.stderr", new_callable=StringIO)
+    def test_xset_mandatory(self, mock_stderr):
+        arguments = self.build_args()
+        test_args = ["-n", "3", "-k", "date"]
+        with self.assertRaises(SystemExit):
+            arguments.parse(test_args)
+        self.assertRegexpMatches(
+            mock_stderr.getvalue(),
+            r"error: the following arguments are required: -m/--model",
+        )
+
+    @patch("sys.stderr", new_callable=StringIO)
+    def test_xset_required(self, mock_stderr):
+        arguments = self.build_args()
+        test_args = ["-n", "3", "-m", "SVC"]
+        with self.assertRaises(SystemExit):
+            arguments.parse(test_args)
+        self.assertRegexpMatches(
+            mock_stderr.getvalue(),
+            r"error: the following arguments are required: -k/--key",
+        )
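A note on the two failure tests: when a required option is missing, `argparse` prints the usage and error message to `sys.stderr` and calls `sys.exit(2)`, so the test captures the exit with `assertRaises(SystemExit)` while the patched `StringIO` records the message (`assertRegexpMatches` is the older name for today's `assertRegex`). A self-contained sketch of the same pattern:

    import unittest
    from argparse import ArgumentParser
    from io import StringIO
    from unittest.mock import patch

    class ExitMessageTest(unittest.TestCase):
        @patch("sys.stderr", new_callable=StringIO)
        def test_missing_required(self, mock_stderr):
            ap = ArgumentParser()
            ap.add_argument("-m", "--model", required=True)
            with self.assertRaises(SystemExit):  # argparse exits with status 2
                ap.parse_args([])
            self.assertRegex(
                mock_stderr.getvalue(),
                r"error: the following arguments are required: -m/--model",
            )

    if __name__ == "__main__":
        unittest.main()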
@@ -62,3 +62,12 @@ class BestResultTest(TestBase):
             best.fill({}),
             {"balance-scale": (0.0, {}, ""), "balloons": (0.0, {}, "")},
         )
+
+    def test_build_error(self):
+        dt = Datasets()
+        model = "SVC"
+        best = BestResults(
+            score="accuracy", model=model, datasets=dt, quiet=True
+        )
+        with self.assertRaises(ValueError):
+            best.build()
@@ -1,7 +1,7 @@
 from io import StringIO
 from unittest.mock import patch
 from .TestBase import TestBase
-from ..Results import Report, BaseReport, ReportBest
+from ..Results import Report, BaseReport, ReportBest, ReportDatasets
 from ..Utils import Symbols


@@ -79,3 +79,9 @@ class ReportTest(TestBase):
         with patch(self.output, new=StringIO()) as fake_out:
             report.report()
         self.check_output_file(fake_out, "report_best.test")
+
+    @patch("sys.stdout", new_callable=StringIO)
+    def test_report_datasets(self, mock_output):
+        report = ReportDatasets()
+        report.report()
+        self.check_output_file(mock_output, "report_datasets.test")
@@ -215,3 +215,17 @@ class SummaryTest(TestBase):
         with patch(self.output, new=StringIO()) as fake_out:
             report.show_top()
         self.check_output_file(fake_out, "summary_show_top.test")
+
+    @patch("sys.stdout", new_callable=StringIO)
+    def test_show_top_no_data(self, fake_out):
+        report = Summary()
+        report.acquire()
+        report.show_top(score="f1-macro")
+        self.assertEqual(fake_out.getvalue(), "** No results found **\n")
+
+    def test_no_data(self):
+        report = Summary()
+        report.acquire()
+        with self.assertRaises(ValueError) as msg:
+            report.list_results(score="f1-macro", model="STree")
+        self.assertEqual(str(msg.exception), "** No results found **")
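`test_no_data` uses the context-manager form of `assertRaises` and keeps the returned context, so the exact exception text can be asserted afterwards. A minimal sketch of that idiom:

    import unittest

    class MessageTest(unittest.TestCase):
        def test_exception_message(self):
            with self.assertRaises(ValueError) as ctx:
                raise ValueError("** No results found **")
            # ctx.exception holds the caught exception instance
            self.assertEqual(str(ctx.exception), "** No results found **")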
@@ -10,6 +10,7 @@ from .SQL_test import SQLTest
 from .Benchmark_test import BenchmarkTest
 from .Summary_test import SummaryTest
 from .PairCheck_test import PairCheckTest
+from .Arguments_test import ArgumentsTest

 all = [
     "UtilTest",
@@ -24,5 +25,5 @@ all = [
     "BenchmarkTest",
     "SummaryTest",
     "PairCheckTest",
-    "be_list",
+    "ArgumentsTest",
 ]
benchmark/tests/test_files/report_datasets.test (new file, 4 lines)
@@ -0,0 +1,4 @@
+[94mDataset                        Samp. Feat. Cls Balance
+============================== ===== ===== === ========================================
+[96mbalance-scale                    625     4   3  7.84%/ 46.08%/ 46.08%
+[94mballoons                          16     4   2 56.25%/ 43.75%
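The `[94m` and `[96m` prefixes in this fixture are ANSI color escapes (`ESC[94m` bright blue, `ESC[96m` bright cyan); the diff view drops the leading ESC byte. A one-liner showing how such colored lines are typically emitted:

    BLUE, CYAN = "\x1b[94m", "\x1b[96m"
    print(f"{CYAN}balance-scale")  # prints the dataset name in bright cyan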