Mirror of https://github.com/Doctorado-ML/benchmark.git
Refactor testing
@@ -4,10 +4,10 @@ from benchmark.Utils import Files
 from benchmark.Arguments import Arguments


-def main():
+def main(args_test=None):
     arguments = Arguments()
     arguments.xset("score").xset("excel").xset("tex_output")
-    ar = arguments.parse()
+    ar = arguments.parse(args_test)
     benchmark = Benchmark(score=ar.score, visualize=True)
     benchmark.compile_results()
     benchmark.save_results()
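This pattern repeats across every script in the commit: main() gains an optional args_test parameter that is forwarded to Arguments.parse(), so tests can pass a CLI argument list directly instead of patching sys.argv. A minimal sketch of such a call follows; the module path benchmark.scripts.be_benchmark and the "-s"/"accuracy" flag are assumptions for illustration, not taken from this diff.

# Hedged sketch: drive a script's main() from a test with an explicit argument
# list. Module path and flag names below are assumed, not shown in the commit.
from io import StringIO
from unittest.mock import patch

from benchmark.scripts import be_benchmark  # hypothetical import path

with patch("sys.stdout", new=StringIO()) as fake_out:
    be_benchmark.main(args_test=["-s", "accuracy"])  # argparse parses this list
print(fake_out.getvalue())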
@@ -4,12 +4,12 @@ from benchmark.Results import Summary
 from benchmark.Arguments import ALL_METRICS, Arguments


-def main():
+def main(args_test=None):
     arguments = Arguments()
     metrics = list(ALL_METRICS)
     metrics.append("all")
     arguments.xset("score", choices=metrics)
-    args = arguments.parse()
+    args = arguments.parse(args_test)
     metrics = ALL_METRICS if args.score == "all" else [args.score]
     summary = Summary()
     summary.acquire()
@@ -7,11 +7,11 @@ from benchmark.Arguments import Arguments
 """


-def main():
+def main(args_test=None):
     arguments = Arguments()
     arguments.xset("score", mandatory=True).xset("report")
     arguments.xset("model", mandatory=True)
-    args = arguments.parse()
+    args = arguments.parse(args_test)
     datasets = Datasets()
     best = BestResults(args.score, args.model, datasets)
     try:
@@ -9,10 +9,10 @@ input grid used optimizing STree
 """


-def main():
+def main(args_test=None):
     arguments = Arguments()
     arguments.xset("model", mandatory=True).xset("score", mandatory=True)
-    args = arguments.parse()
+    args = arguments.parse(args_test)
     data = [
         '{"C": 1e4, "gamma": 0.1, "kernel": "rbf"}',
         '{"C": 7, "gamma": 0.14, "kernel": "rbf"}',
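As an aside, the entries in data are JSON strings. Assuming they are later decoded (this diff does not show where), something like json.loads turns one of them into a parameter dictionary:

import json

# Hedged sketch only: decoding one of the hyperparameter strings shown above.
params = json.loads('{"C": 7, "gamma": 0.14, "kernel": "rbf"}')
# params == {"C": 7, "gamma": 0.14, "kernel": "rbf"}, usable as keyword arguments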
@@ -6,11 +6,11 @@ from benchmark.Arguments import Arguments
 """


-def main():
+def main(args_test=None):
     arguments = Arguments()
     arguments.xset("score").xset("platform").xset("model", mandatory=True)
     arguments.xset("quiet").xset("stratified").xset("dataset").xset("n_folds")
-    args = arguments.parse()
+    args = arguments.parse(args_test)
     if not args.quiet:
         print(f"Perform grid search with {args.model} model")
     job = GridSearch(
@@ -8,13 +8,13 @@ from benchmark.Arguments import Arguments
 """


-def main():
+def main(args_test=None):
     arguments = Arguments()
     arguments.xset("stratified").xset("score").xset("model", mandatory=True)
     arguments.xset("n_folds").xset("platform").xset("quiet").xset("title")
     arguments.xset("hyperparameters").xset("paramfile").xset("report")
     arguments.xset("grid_paramfile").xset("dataset")
-    args = arguments.parse()
+    args = arguments.parse(args_test)
     report = args.report or args.dataset is not None
     if args.grid_paramfile:
         args.paramfile = False
@@ -81,10 +81,10 @@ def print_stree(clf, dataset, X, y, color, quiet):
     subprocess.run([cmd_open, f"{file_name}.png"])


-def main():
+def main(args_test=None):
     arguments = Arguments()
     arguments.xset("color").xset("dataset", default="all").xset("quiet")
-    args = arguments.parse()
+    args = arguments.parse(args_test)
     hyperparameters = load_hyperparams("accuracy", "ODTE")
     random_state = 57
     dt = Datasets()
@@ -10,13 +10,13 @@ If no argument is set, displays the datasets and its characteristics
 """


-def main():
+def main(args_test=None):
     arguments = Arguments()
     arguments.xset("file").xset("excel").xset("sql").xset("compare")
     arguments.xset("best").xset("grid").xset("model", required=False).xset(
         "score"
     )
-    args = arguments.parse()
+    args = arguments.parse(args_test)
     if args.grid:
         args.best = False
     if args.file is None and args.best is None:
@@ -3,12 +3,12 @@ from benchmark.Results import Summary
 from benchmark.Arguments import ALL_METRICS, Arguments


-def main():
+def main(args_test=None):
     arguments = Arguments()
     metrics = list(ALL_METRICS)
     metrics.append("all")
     arguments.xset("score", choices=metrics).xset("model", required=False)
-    args = arguments.parse()
+    args = arguments.parse(args_test)
     metrics = ALL_METRICS if args.score == "all" else [args.score]
     summary = Summary()
     summary.acquire()
@@ -29,9 +29,7 @@ class BenchmarkTest(TestBase):
         benchmark = Benchmark("accuracy", visualize=False)
         benchmark.compile_results()
         benchmark.save_results()
-        self.check_file_file(
-            benchmark.get_result_file_name(), "exreport_csv.test"
-        )
+        self.check_file_file(benchmark.get_result_file_name(), "exreport_csv")

     def test_exreport_report(self):
         benchmark = Benchmark("accuracy", visualize=False)
@@ -39,7 +37,7 @@ class BenchmarkTest(TestBase):
         benchmark.save_results()
         with patch(self.output, new=StringIO()) as fake_out:
             benchmark.report(tex_output=False)
-        self.check_output_file(fake_out, "exreport_report.test")
+        self.check_output_file(fake_out, "exreport_report")

     def test_exreport(self):
         benchmark = Benchmark("accuracy", visualize=False)
@@ -75,7 +73,7 @@ class BenchmarkTest(TestBase):
         benchmark.save_results()
         with patch(self.output, new=StringIO()) as fake_out:
             benchmark.exreport()
-        self.check_output_file(fake_out, "exreport_error.test")
+        self.check_output_file(fake_out, "exreport_error")

     def test_tex_output(self):
         benchmark = Benchmark("accuracy", visualize=False)
@@ -83,11 +81,9 @@
         benchmark.save_results()
         with patch(self.output, new=StringIO()) as fake_out:
             benchmark.report(tex_output=True)
-        with open(os.path.join(self.test_files, "exreport_report.test")) as f:
-            expected = f.read()
-        self.assertEqual(fake_out.getvalue(), expected)
+        self.check_output_file(fake_out, "exreport_report")
         self.assertTrue(os.path.exists(benchmark.get_tex_file()))
-        self.check_file_file(benchmark.get_tex_file(), "exreport_tex.test")
+        self.check_file_file(benchmark.get_tex_file(), "exreport_tex")

     def test_excel_output(self):
         benchmark = Benchmark("accuracy", visualize=False)
@@ -100,7 +96,7 @@
         book = load_workbook(file_name)
         for sheet_name in book.sheetnames:
             sheet = book[sheet_name]
-            self.check_excel_sheet(sheet, f"exreport_excel_{sheet_name}.test")
+            self.check_excel_sheet(sheet, f"exreport_excel_{sheet_name}")
             # ExcelTest.generate_excel_sheet(
-            #     self, sheet, f"exreport_excel_{sheet_name}.test"
+            #     self, sheet, f"exreport_excel_{sheet_name}"
             # )
@@ -23,7 +23,7 @@ class ExcelTest(TestBase):
         file_output = report.get_file_name()
         book = load_workbook(file_output)
         sheet = book["STree"]
-        self.check_excel_sheet(sheet, "excel_compared.test")
+        self.check_excel_sheet(sheet, "excel_compared")

     def test_report_excel(self):
         file_name = "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json"
@@ -32,7 +32,7 @@ class ExcelTest(TestBase):
         file_output = report.get_file_name()
         book = load_workbook(file_output)
         sheet = book["STree"]
-        self.check_excel_sheet(sheet, "excel.test")
+        self.check_excel_sheet(sheet, "excel")

     def test_Excel_Add_sheet(self):
         file_name = "results_accuracy_STree_iMac27_2021-10-27_09:40:40_0.json"
@@ -48,6 +48,6 @@ class ExcelTest(TestBase):
         book.close()
         book = load_workbook(os.path.join(Folders.results, excel_file_name))
         sheet = book["STree"]
-        self.check_excel_sheet(sheet, "excel_add_STree.test")
+        self.check_excel_sheet(sheet, "excel_add_STree")
         sheet = book["ODTE"]
-        self.check_excel_sheet(sheet, "excel_add_ODTE.test")
+        self.check_excel_sheet(sheet, "excel_add_ODTE")
@@ -21,14 +21,14 @@ class PairCheckTest(TestBase):
         report.compute()
         with patch(self.output, new=StringIO()) as fake_out:
             report.report()
-        self.check_output_file(fake_out, "paircheck.test")
+        self.check_output_file(fake_out, "paircheck")

     def test_pair_check_win(self):
         report = self.build_model(win=True)
         report.compute()
         with patch(self.output, new=StringIO()) as fake_out:
             report.report()
-        self.check_output_file(fake_out, "paircheck_win.test")
+        self.check_output_file(fake_out, "paircheck_win")

     def test_pair_check_lose(self):
         report = self.build_model(
@@ -37,14 +37,14 @@ class PairCheckTest(TestBase):
         report.compute()
         with patch(self.output, new=StringIO()) as fake_out:
             report.report()
-        self.check_output_file(fake_out, "paircheck_lose.test")
+        self.check_output_file(fake_out, "paircheck_lose")

     def test_pair_check_win_lose(self):
         report = self.build_model(win=True, lose=True)
         report.compute()
         with patch(self.output, new=StringIO()) as fake_out:
             report.report()
-        self.check_output_file(fake_out, "paircheck_win_lose.test")
+        self.check_output_file(fake_out, "paircheck_win_lose")

     def test_pair_check_store_result(self):
         report = self.build_model(win=True, lose=True)
@@ -1,3 +1,4 @@
+import os
 from io import StringIO
 from unittest.mock import patch
 from .TestBase import TestBase
@@ -8,9 +9,9 @@ from ..Utils import Symbols
 class ReportTest(TestBase):
     def test_BaseReport(self):
         with patch.multiple(BaseReport, __abstractmethods__=set()):
-            file_name = (
-                "results/results_accuracy_STree_iMac27_2021-09-30_11:"
-                "42:07_0.json"
+            file_name = os.path.join(
+                "results",
+                "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json",
             )
             a = BaseReport(file_name)
             self.assertIsNone(a.header())
@@ -19,12 +20,14 @@ class ReportTest(TestBase):

     def test_report_with_folder(self):
         report = Report(
-            file_name="results/results_accuracy_STree_iMac27_2021-09-30_11:"
-            "42:07_0.json"
+            file_name=os.path.join(
+                "results",
+                "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json",
+            )
         )
         with patch(self.output, new=StringIO()) as fake_out:
             report.report()
-        self.check_output_file(fake_out, "report.test")
+        self.check_output_file(fake_out, "report")

     def test_report_without_folder(self):
         report = Report(
@@ -33,7 +36,7 @@ class ReportTest(TestBase):
         )
         with patch(self.output, new=StringIO()) as fake_out:
             report.report()
-        self.check_output_file(fake_out, "report.test")
+        self.check_output_file(fake_out, "report")

     def test_report_compared(self):
         report = Report(
@@ -43,7 +46,7 @@ class ReportTest(TestBase):
         )
         with patch(self.output, new=StringIO()) as fake_out:
             report.report()
-        self.check_output_file(fake_out, "report_compared.test")
+        self.check_output_file(fake_out, "report_compared")

     def test_compute_status(self):
         file_name = "results_accuracy_STree_iMac27_2021-10-27_09:40:40_0.json"
@@ -66,22 +69,22 @@
         report = ReportBest("accuracy", "STree", best=True, grid=False)
         with patch(self.output, new=StringIO()) as fake_out:
             report.report()
-        self.check_output_file(fake_out, "report_best.test")
+        self.check_output_file(fake_out, "report_best")

     def test_report_grid(self):
         report = ReportBest("accuracy", "STree", best=False, grid=True)
         with patch(self.output, new=StringIO()) as fake_out:
             report.report()
-        self.check_output_file(fake_out, "report_grid.test")
+        self.check_output_file(fake_out, "report_grid")

     def test_report_best_both(self):
         report = ReportBest("accuracy", "STree", best=True, grid=True)
         with patch(self.output, new=StringIO()) as fake_out:
             report.report()
-        self.check_output_file(fake_out, "report_best.test")
+        self.check_output_file(fake_out, "report_best")

     @patch("sys.stdout", new_callable=StringIO)
     def test_report_datasets(self, mock_output):
         report = ReportDatasets()
         report.report()
-        self.check_output_file(mock_output, "report_datasets.test")
+        self.check_output_file(mock_output, "report_datasets")
@@ -19,4 +19,4 @@ class SQLTest(TestBase):
         file_name = os.path.join(
             Folders.results, file_name.replace(".json", ".sql")
         )
-        self.check_file_file(file_name, "sql.test")
+        self.check_file_file(file_name, "sql")
@@ -132,28 +132,28 @@ class SummaryTest(TestBase):
         report.acquire()
         with patch(self.output, new=StringIO()) as fake_out:
             report.list_results(model="STree")
-        self.check_output_file(fake_out, "summary_list_model.test")
+        self.check_output_file(fake_out, "summary_list_model")

     def test_summary_list_results_score(self):
         report = Summary()
         report.acquire()
         with patch(self.output, new=StringIO()) as fake_out:
             report.list_results(score="accuracy")
-        self.check_output_file(fake_out, "summary_list_score.test")
+        self.check_output_file(fake_out, "summary_list_score")

     def test_summary_list_results_n(self):
         report = Summary()
         report.acquire()
         with patch(self.output, new=StringIO()) as fake_out:
             report.list_results(score="accuracy", number=3)
-        self.check_output_file(fake_out, "summary_list_n.test")
+        self.check_output_file(fake_out, "summary_list_n")

     def test_summary_list_hidden(self):
         report = Summary(hidden=True)
         report.acquire()
         with patch(self.output, new=StringIO()) as fake_out:
             report.list_results(score="accuracy")
-        self.check_output_file(fake_out, "summary_list_hidden.test")
+        self.check_output_file(fake_out, "summary_list_hidden")

     def test_show_result_no_title(self):
         report = Summary()
@@ -164,7 +164,7 @@ class SummaryTest(TestBase):
                 criterion="model", value="STree", score="accuracy"
             )
             report.show_result(data=best, title=title)
-        self.check_output_file(fake_out, "summary_show_results.test")
+        self.check_output_file(fake_out, "summary_show_results")

     def test_show_result_title(self):
         report = Summary()
@@ -175,7 +175,7 @@ class SummaryTest(TestBase):
                 criterion="model", value="STree", score="accuracy"
             )
             report.show_result(data=best, title=title)
-        self.check_output_file(fake_out, "summary_show_results_title.test")
+        self.check_output_file(fake_out, "summary_show_results_title")

     def test_show_result_no_data(self):
         report = Summary()
@@ -214,7 +214,7 @@ class SummaryTest(TestBase):
         report.acquire()
         with patch(self.output, new=StringIO()) as fake_out:
             report.show_top()
-        self.check_output_file(fake_out, "summary_show_top.test")
+        self.check_output_file(fake_out, "summary_show_top")

     @patch("sys.stdout", new_callable=StringIO)
     def test_show_top_no_data(self, fake_out):
@@ -31,6 +31,7 @@ class TestBase(unittest.TestCase):
             print(f'{row};{col};"{value}"', file=f)

     def check_excel_sheet(self, sheet, file_name):
+        file_name += ".test"
         with open(os.path.join(self.test_files, file_name), "r") as f:
             expected = csv.reader(f, delimiter=";")
             for row, col, value in expected:
@@ -44,6 +45,7 @@ class TestBase(unittest.TestCase):
                 self.assertEqual(sheet.cell(int(row), int(col)).value, value)

     def check_output_file(self, output, file_name):
+        file_name += ".test"
         with open(os.path.join(self.test_files, file_name)) as f:
             expected = f.read()
         self.assertEqual(output.getvalue(), expected)
@@ -51,6 +53,7 @@ class TestBase(unittest.TestCase):
     def check_file_file(self, computed_file, expected_file):
         with open(computed_file) as f:
             computed = f.read()
+        expected_file += ".test"
         with open(os.path.join(self.test_files, expected_file)) as f:
             expected = f.read()
         self.assertEqual(computed, expected)
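The three TestBase hunks above share one convention: the check_* helpers now append the ".test" suffix themselves, which is why every call site elsewhere in this commit drops it. A self-contained sketch of the resulting helper, assuming a fixtures folder named test_files (the repository's exact class is larger than this):

import os
import unittest


class TestBaseSketch(unittest.TestCase):
    test_files = "test_files"  # hypothetical fixtures directory

    def check_output_file(self, output, file_name):
        file_name += ".test"  # suffix added once here, not at every call site
        with open(os.path.join(self.test_files, file_name)) as f:
            expected = f.read()
        self.assertEqual(output.getvalue(), expected)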
@@ -11,8 +11,9 @@ from .Benchmark_test import BenchmarkTest
 from .Summary_test import SummaryTest
 from .PairCheck_test import PairCheckTest
 from .Arguments_test import ArgumentsTest
-from .scripts.Pair_check_test import BePairCheckTest
-from .scripts.List_test import ListTest
+from .scripts.Be_Pair_check_test import BePairCheckTest
+from .scripts.Be_List_test import BeListTest
+from .scripts.Be_Report_test import BeReportTest

 all = [
     "UtilTest",
@@ -29,5 +30,6 @@ all = [
     "PairCheckTest",
     "ArgumentsTest",
     "BePairCheckTest",
-    "ListTest",
+    "BeListTest",
+    "BeReportTest",
 ]
@@ -1,14 +1,14 @@
 from ..TestBase import TestBase


-class ListTest(TestBase):
+class BeListTest(TestBase):
     def setUp(self):
         self.prepare_scripts_env()

     def test_be_list(self):
         stdout, stderr = self.execute_script("be_list", ["-m", "STree"])
         self.assertEqual(stderr.getvalue(), "")
-        self.check_output_file(stdout, "summary_list_model.test")
+        self.check_output_file(stdout, "summary_list_model")

     def test_be_list_no_data(self):
         stdout, stderr = self.execute_script(
@@ -10,7 +10,7 @@ class BePairCheckTest(TestBase):
             "be_pair_check", ["-m1", "ODTE", "-m2", "STree"]
         )
         self.assertEqual(stderr.getvalue(), "")
-        self.check_output_file(stdout, "paircheck.test")
+        self.check_output_file(stdout, "paircheck")

     def test_be_pair_check_no_data_a(self):
         stdout, stderr = self.execute_script(
benchmark/tests/scripts/Be_Report_test.py (new file, 25 lines)
@@ -0,0 +1,25 @@
+import os
+from ..TestBase import TestBase
+
+
+class BeReportTest(TestBase):
+    def setUp(self):
+        self.prepare_scripts_env()
+
+    def test_be_report(self):
+        file_name = os.path.join(
+            "results",
+            "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json",
+        )
+        stdout, stderr = self.execute_script("be_report", ["-f", file_name])
+        self.assertEqual(stderr.getvalue(), "")
+        self.check_output_file(stdout, "report")
+
+    def test_be_report_compare(self):
+        file_name = "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json"
+        stdout, stderr = self.execute_script(
+            "be_report",
+            ["-f", file_name, "-c", "1"],
+        )
+        self.assertEqual(stderr.getvalue(), "")
+        self.check_output_file(stdout, "report_compared")
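The script tests rely on prepare_scripts_env and execute_script, whose implementations are not shown in this diff. Given the new main(args_test=None) entry points, a plausible shape for execute_script is sketched below; the module path and name resolution are assumptions, not the repository's actual helper.

import importlib
from io import StringIO
from unittest.mock import patch


def execute_script(name, args):
    # Assumed module location; the real helper may resolve script names differently.
    module = importlib.import_module(f"benchmark.scripts.{name}")
    stdout, stderr = StringIO(), StringIO()
    with patch("sys.stdout", new=stdout), patch("sys.stderr", new=stderr):
        module.main(args_test=args)  # pass the argument list straight to argparse
    return stdout, stderr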