Refactor testing

This commit is contained in:
2022-05-07 01:33:35 +02:00
parent 3b214773ff
commit df757fefcd
20 changed files with 92 additions and 63 deletions

View File

@@ -4,10 +4,10 @@ from benchmark.Utils import Files
from benchmark.Arguments import Arguments
def main():
def main(args_test=None):
arguments = Arguments()
arguments.xset("score").xset("excel").xset("tex_output")
ar = arguments.parse()
ar = arguments.parse(args_test)
benchmark = Benchmark(score=ar.score, visualize=True)
benchmark.compile_results()
benchmark.save_results()

View File

@@ -4,12 +4,12 @@ from benchmark.Results import Summary
from benchmark.Arguments import ALL_METRICS, Arguments
def main():
def main(args_test=None):
arguments = Arguments()
metrics = list(ALL_METRICS)
metrics.append("all")
arguments.xset("score", choices=metrics)
args = arguments.parse()
args = arguments.parse(args_test)
metrics = ALL_METRICS if args.score == "all" else [args.score]
summary = Summary()
summary.acquire()

View File

@@ -7,11 +7,11 @@ from benchmark.Arguments import Arguments
"""
def main():
def main(args_test=None):
arguments = Arguments()
arguments.xset("score", mandatory=True).xset("report")
arguments.xset("model", mandatory=True)
args = arguments.parse()
args = arguments.parse(args_test)
datasets = Datasets()
best = BestResults(args.score, args.model, datasets)
try:

View File

@@ -9,10 +9,10 @@ input grid used optimizing STree
"""
def main():
def main(args_test=None):
arguments = Arguments()
arguments.xset("model", mandatory=True).xset("score", mandatory=True)
args = arguments.parse()
args = arguments.parse(args_test)
data = [
'{"C": 1e4, "gamma": 0.1, "kernel": "rbf"}',
'{"C": 7, "gamma": 0.14, "kernel": "rbf"}',

View File

@@ -6,11 +6,11 @@ from benchmark.Arguments import Arguments
"""
def main():
def main(args_test=None):
arguments = Arguments()
arguments.xset("score").xset("platform").xset("model", mandatory=True)
arguments.xset("quiet").xset("stratified").xset("dataset").xset("n_folds")
args = arguments.parse()
args = arguments.parse(args_test)
if not args.quiet:
print(f"Perform grid search with {args.model} model")
job = GridSearch(

View File

@@ -8,13 +8,13 @@ from benchmark.Arguments import Arguments
"""
def main():
def main(args_test=None):
arguments = Arguments()
arguments.xset("stratified").xset("score").xset("model", mandatory=True)
arguments.xset("n_folds").xset("platform").xset("quiet").xset("title")
arguments.xset("hyperparameters").xset("paramfile").xset("report")
arguments.xset("grid_paramfile").xset("dataset")
args = arguments.parse()
args = arguments.parse(args_test)
report = args.report or args.dataset is not None
if args.grid_paramfile:
args.paramfile = False

View File

@@ -81,10 +81,10 @@ def print_stree(clf, dataset, X, y, color, quiet):
subprocess.run([cmd_open, f"{file_name}.png"])
def main():
def main(args_test=None):
arguments = Arguments()
arguments.xset("color").xset("dataset", default="all").xset("quiet")
args = arguments.parse()
args = arguments.parse(args_test)
hyperparameters = load_hyperparams("accuracy", "ODTE")
random_state = 57
dt = Datasets()

View File

@@ -10,13 +10,13 @@ If no argument is set, displays the datasets and its characteristics
"""
def main():
def main(args_test=None):
arguments = Arguments()
arguments.xset("file").xset("excel").xset("sql").xset("compare")
arguments.xset("best").xset("grid").xset("model", required=False).xset(
"score"
)
args = arguments.parse()
args = arguments.parse(args_test)
if args.grid:
args.best = False
if args.file is None and args.best is None:

View File

@@ -3,12 +3,12 @@ from benchmark.Results import Summary
from benchmark.Arguments import ALL_METRICS, Arguments
def main():
def main(args_test=None):
arguments = Arguments()
metrics = list(ALL_METRICS)
metrics.append("all")
arguments.xset("score", choices=metrics).xset("model", required=False)
args = arguments.parse()
args = arguments.parse(args_test)
metrics = ALL_METRICS if args.score == "all" else [args.score]
summary = Summary()
summary.acquire()

View File

@@ -29,9 +29,7 @@ class BenchmarkTest(TestBase):
benchmark = Benchmark("accuracy", visualize=False)
benchmark.compile_results()
benchmark.save_results()
self.check_file_file(
benchmark.get_result_file_name(), "exreport_csv.test"
)
self.check_file_file(benchmark.get_result_file_name(), "exreport_csv")
def test_exreport_report(self):
benchmark = Benchmark("accuracy", visualize=False)
@@ -39,7 +37,7 @@ class BenchmarkTest(TestBase):
benchmark.save_results()
with patch(self.output, new=StringIO()) as fake_out:
benchmark.report(tex_output=False)
self.check_output_file(fake_out, "exreport_report.test")
self.check_output_file(fake_out, "exreport_report")
def test_exreport(self):
benchmark = Benchmark("accuracy", visualize=False)
@@ -75,7 +73,7 @@ class BenchmarkTest(TestBase):
benchmark.save_results()
with patch(self.output, new=StringIO()) as fake_out:
benchmark.exreport()
self.check_output_file(fake_out, "exreport_error.test")
self.check_output_file(fake_out, "exreport_error")
def test_tex_output(self):
benchmark = Benchmark("accuracy", visualize=False)
@@ -83,11 +81,9 @@ class BenchmarkTest(TestBase):
benchmark.save_results()
with patch(self.output, new=StringIO()) as fake_out:
benchmark.report(tex_output=True)
with open(os.path.join(self.test_files, "exreport_report.test")) as f:
expected = f.read()
self.assertEqual(fake_out.getvalue(), expected)
self.check_output_file(fake_out, "exreport_report")
self.assertTrue(os.path.exists(benchmark.get_tex_file()))
self.check_file_file(benchmark.get_tex_file(), "exreport_tex.test")
self.check_file_file(benchmark.get_tex_file(), "exreport_tex")
def test_excel_output(self):
benchmark = Benchmark("accuracy", visualize=False)
@@ -100,7 +96,7 @@ class BenchmarkTest(TestBase):
book = load_workbook(file_name)
for sheet_name in book.sheetnames:
sheet = book[sheet_name]
self.check_excel_sheet(sheet, f"exreport_excel_{sheet_name}.test")
self.check_excel_sheet(sheet, f"exreport_excel_{sheet_name}")
# ExcelTest.generate_excel_sheet(
# self, sheet, f"exreport_excel_{sheet_name}.test"
# self, sheet, f"exreport_excel_{sheet_name}"
# )

View File

@@ -23,7 +23,7 @@ class ExcelTest(TestBase):
file_output = report.get_file_name()
book = load_workbook(file_output)
sheet = book["STree"]
self.check_excel_sheet(sheet, "excel_compared.test")
self.check_excel_sheet(sheet, "excel_compared")
def test_report_excel(self):
file_name = "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json"
@@ -32,7 +32,7 @@ class ExcelTest(TestBase):
file_output = report.get_file_name()
book = load_workbook(file_output)
sheet = book["STree"]
self.check_excel_sheet(sheet, "excel.test")
self.check_excel_sheet(sheet, "excel")
def test_Excel_Add_sheet(self):
file_name = "results_accuracy_STree_iMac27_2021-10-27_09:40:40_0.json"
@@ -48,6 +48,6 @@ class ExcelTest(TestBase):
book.close()
book = load_workbook(os.path.join(Folders.results, excel_file_name))
sheet = book["STree"]
self.check_excel_sheet(sheet, "excel_add_STree.test")
self.check_excel_sheet(sheet, "excel_add_STree")
sheet = book["ODTE"]
self.check_excel_sheet(sheet, "excel_add_ODTE.test")
self.check_excel_sheet(sheet, "excel_add_ODTE")

View File

@@ -21,14 +21,14 @@ class PairCheckTest(TestBase):
report.compute()
with patch(self.output, new=StringIO()) as fake_out:
report.report()
self.check_output_file(fake_out, "paircheck.test")
self.check_output_file(fake_out, "paircheck")
def test_pair_check_win(self):
report = self.build_model(win=True)
report.compute()
with patch(self.output, new=StringIO()) as fake_out:
report.report()
self.check_output_file(fake_out, "paircheck_win.test")
self.check_output_file(fake_out, "paircheck_win")
def test_pair_check_lose(self):
report = self.build_model(
@@ -37,14 +37,14 @@ class PairCheckTest(TestBase):
report.compute()
with patch(self.output, new=StringIO()) as fake_out:
report.report()
self.check_output_file(fake_out, "paircheck_lose.test")
self.check_output_file(fake_out, "paircheck_lose")
def test_pair_check_win_lose(self):
report = self.build_model(win=True, lose=True)
report.compute()
with patch(self.output, new=StringIO()) as fake_out:
report.report()
self.check_output_file(fake_out, "paircheck_win_lose.test")
self.check_output_file(fake_out, "paircheck_win_lose")
def test_pair_check_store_result(self):
report = self.build_model(win=True, lose=True)

View File

@@ -1,3 +1,4 @@
import os
from io import StringIO
from unittest.mock import patch
from .TestBase import TestBase
@@ -8,9 +9,9 @@ from ..Utils import Symbols
class ReportTest(TestBase):
def test_BaseReport(self):
with patch.multiple(BaseReport, __abstractmethods__=set()):
file_name = (
"results/results_accuracy_STree_iMac27_2021-09-30_11:"
"42:07_0.json"
file_name = os.path.join(
"results",
"results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json",
)
a = BaseReport(file_name)
self.assertIsNone(a.header())
@@ -19,12 +20,14 @@ class ReportTest(TestBase):
def test_report_with_folder(self):
report = Report(
file_name="results/results_accuracy_STree_iMac27_2021-09-30_11:"
"42:07_0.json"
file_name=os.path.join(
"results",
"results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json",
)
)
with patch(self.output, new=StringIO()) as fake_out:
report.report()
self.check_output_file(fake_out, "report.test")
self.check_output_file(fake_out, "report")
def test_report_without_folder(self):
report = Report(
@@ -33,7 +36,7 @@ class ReportTest(TestBase):
)
with patch(self.output, new=StringIO()) as fake_out:
report.report()
self.check_output_file(fake_out, "report.test")
self.check_output_file(fake_out, "report")
def test_report_compared(self):
report = Report(
@@ -43,7 +46,7 @@ class ReportTest(TestBase):
)
with patch(self.output, new=StringIO()) as fake_out:
report.report()
self.check_output_file(fake_out, "report_compared.test")
self.check_output_file(fake_out, "report_compared")
def test_compute_status(self):
file_name = "results_accuracy_STree_iMac27_2021-10-27_09:40:40_0.json"
@@ -66,22 +69,22 @@ class ReportTest(TestBase):
report = ReportBest("accuracy", "STree", best=True, grid=False)
with patch(self.output, new=StringIO()) as fake_out:
report.report()
self.check_output_file(fake_out, "report_best.test")
self.check_output_file(fake_out, "report_best")
def test_report_grid(self):
report = ReportBest("accuracy", "STree", best=False, grid=True)
with patch(self.output, new=StringIO()) as fake_out:
report.report()
self.check_output_file(fake_out, "report_grid.test")
self.check_output_file(fake_out, "report_grid")
def test_report_best_both(self):
report = ReportBest("accuracy", "STree", best=True, grid=True)
with patch(self.output, new=StringIO()) as fake_out:
report.report()
self.check_output_file(fake_out, "report_best.test")
self.check_output_file(fake_out, "report_best")
@patch("sys.stdout", new_callable=StringIO)
def test_report_datasets(self, mock_output):
report = ReportDatasets()
report.report()
self.check_output_file(mock_output, "report_datasets.test")
self.check_output_file(mock_output, "report_datasets")

View File

@@ -19,4 +19,4 @@ class SQLTest(TestBase):
file_name = os.path.join(
Folders.results, file_name.replace(".json", ".sql")
)
self.check_file_file(file_name, "sql.test")
self.check_file_file(file_name, "sql")

View File

@@ -132,28 +132,28 @@ class SummaryTest(TestBase):
report.acquire()
with patch(self.output, new=StringIO()) as fake_out:
report.list_results(model="STree")
self.check_output_file(fake_out, "summary_list_model.test")
self.check_output_file(fake_out, "summary_list_model")
def test_summary_list_results_score(self):
report = Summary()
report.acquire()
with patch(self.output, new=StringIO()) as fake_out:
report.list_results(score="accuracy")
self.check_output_file(fake_out, "summary_list_score.test")
self.check_output_file(fake_out, "summary_list_score")
def test_summary_list_results_n(self):
report = Summary()
report.acquire()
with patch(self.output, new=StringIO()) as fake_out:
report.list_results(score="accuracy", number=3)
self.check_output_file(fake_out, "summary_list_n.test")
self.check_output_file(fake_out, "summary_list_n")
def test_summary_list_hidden(self):
report = Summary(hidden=True)
report.acquire()
with patch(self.output, new=StringIO()) as fake_out:
report.list_results(score="accuracy")
self.check_output_file(fake_out, "summary_list_hidden.test")
self.check_output_file(fake_out, "summary_list_hidden")
def test_show_result_no_title(self):
report = Summary()
@@ -164,7 +164,7 @@ class SummaryTest(TestBase):
criterion="model", value="STree", score="accuracy"
)
report.show_result(data=best, title=title)
self.check_output_file(fake_out, "summary_show_results.test")
self.check_output_file(fake_out, "summary_show_results")
def test_show_result_title(self):
report = Summary()
@@ -175,7 +175,7 @@ class SummaryTest(TestBase):
criterion="model", value="STree", score="accuracy"
)
report.show_result(data=best, title=title)
self.check_output_file(fake_out, "summary_show_results_title.test")
self.check_output_file(fake_out, "summary_show_results_title")
def test_show_result_no_data(self):
report = Summary()
@@ -214,7 +214,7 @@ class SummaryTest(TestBase):
report.acquire()
with patch(self.output, new=StringIO()) as fake_out:
report.show_top()
self.check_output_file(fake_out, "summary_show_top.test")
self.check_output_file(fake_out, "summary_show_top")
@patch("sys.stdout", new_callable=StringIO)
def test_show_top_no_data(self, fake_out):

View File

@@ -31,6 +31,7 @@ class TestBase(unittest.TestCase):
print(f'{row};{col};"{value}"', file=f)
def check_excel_sheet(self, sheet, file_name):
file_name += ".test"
with open(os.path.join(self.test_files, file_name), "r") as f:
expected = csv.reader(f, delimiter=";")
for row, col, value in expected:
@@ -44,6 +45,7 @@ class TestBase(unittest.TestCase):
self.assertEqual(sheet.cell(int(row), int(col)).value, value)
def check_output_file(self, output, file_name):
file_name += ".test"
with open(os.path.join(self.test_files, file_name)) as f:
expected = f.read()
self.assertEqual(output.getvalue(), expected)
@@ -51,6 +53,7 @@ class TestBase(unittest.TestCase):
def check_file_file(self, computed_file, expected_file):
with open(computed_file) as f:
computed = f.read()
expected_file += ".test"
with open(os.path.join(self.test_files, expected_file)) as f:
expected = f.read()
self.assertEqual(computed, expected)

View File

@@ -11,8 +11,9 @@ from .Benchmark_test import BenchmarkTest
from .Summary_test import SummaryTest
from .PairCheck_test import PairCheckTest
from .Arguments_test import ArgumentsTest
from .scripts.Pair_check_test import BePairCheckTest
from .scripts.List_test import ListTest
from .scripts.Be_Pair_check_test import BePairCheckTest
from .scripts.Be_List_test import BeListTest
from .scripts.Be_Report_test import BeReportTest
all = [
"UtilTest",
@@ -29,5 +30,6 @@ all = [
"PairCheckTest",
"ArgumentsTest",
"BePairCheckTest",
"ListTest",
"BeListTest",
"BeReportTest",
]

View File

@@ -1,14 +1,14 @@
from ..TestBase import TestBase
class ListTest(TestBase):
class BeListTest(TestBase):
def setUp(self):
self.prepare_scripts_env()
def test_be_list(self):
stdout, stderr = self.execute_script("be_list", ["-m", "STree"])
self.assertEqual(stderr.getvalue(), "")
self.check_output_file(stdout, "summary_list_model.test")
self.check_output_file(stdout, "summary_list_model")
def test_be_list_no_data(self):
stdout, stderr = self.execute_script(

View File

@@ -10,7 +10,7 @@ class BePairCheckTest(TestBase):
"be_pair_check", ["-m1", "ODTE", "-m2", "STree"]
)
self.assertEqual(stderr.getvalue(), "")
self.check_output_file(stdout, "paircheck.test")
self.check_output_file(stdout, "paircheck")
def test_be_pair_check_no_data_a(self):
stdout, stderr = self.execute_script(

View File

@@ -0,0 +1,25 @@
import os
from ..TestBase import TestBase
class BeReportTest(TestBase):
def setUp(self):
self.prepare_scripts_env()
def test_be_report(self):
file_name = os.path.join(
"results",
"results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json",
)
stdout, stderr = self.execute_script("be_report", ["-f", file_name])
self.assertEqual(stderr.getvalue(), "")
self.check_output_file(stdout, "report")
def test_be_report_compare(self):
file_name = "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json"
stdout, stderr = self.execute_script(
"be_report",
["-f", file_name, "-c", "1"],
)
self.assertEqual(stderr.getvalue(), "")
self.check_output_file(stdout, "report_compared")