Refactor tests

This commit is contained in:
2022-04-29 17:06:11 +02:00
parent 3f4a04ab50
commit a719098154
13 changed files with 127 additions and 221 deletions

View File

@@ -1,19 +1,14 @@
import os import os
import unittest
import shutil import shutil
from io import StringIO from io import StringIO
from unittest.mock import patch from unittest.mock import patch
from openpyxl import load_workbook from openpyxl import load_workbook
from .TestBase import TestBase
from ..Utils import Folders, Files from ..Utils import Folders, Files
from ..Results import Benchmark from ..Results import Benchmark
from .Excel_test import ExcelTest
class BenchmarkTest(unittest.TestCase): class BenchmarkTest(TestBase):
def __init__(self, *args, **kwargs):
os.chdir(os.path.dirname(os.path.abspath(__file__)))
super().__init__(*args, **kwargs)
def tearDown(self) -> None: def tearDown(self) -> None:
benchmark = Benchmark("accuracy", visualize=False) benchmark = Benchmark("accuracy", visualize=False)
files = [ files = [
@@ -41,29 +36,25 @@ class BenchmarkTest(unittest.TestCase):
benchmark = Benchmark("accuracy", visualize=False) benchmark = Benchmark("accuracy", visualize=False)
benchmark.compile_results() benchmark.compile_results()
benchmark.save_results() benchmark.save_results()
with open(benchmark.get_result_file_name()) as f: self.check_file_file(
computed = f.readlines() benchmark.get_result_file_name(), "exreport_csv.test"
with open(os.path.join("test_files", "exreport_csv.test")) as f_exp: )
expected = f_exp.readlines()
self.assertEqual(computed, expected)
def test_exreport_report(self): def test_exreport_report(self):
benchmark = Benchmark("accuracy", visualize=False) benchmark = Benchmark("accuracy", visualize=False)
benchmark.compile_results() benchmark.compile_results()
benchmark.save_results() benchmark.save_results()
with patch("sys.stdout", new=StringIO()) as fake_out: with patch(self.output, new=StringIO()) as fake_out:
benchmark.report(tex_output=False) benchmark.report(tex_output=False)
with open(os.path.join("test_files", "exreport_report.test")) as f: self.check_output_file(fake_out, "exreport_report.test")
expected = f.read()
self.assertEqual(fake_out.getvalue(), expected)
def test_exreport(self): def test_exreport(self):
benchmark = Benchmark("accuracy", visualize=False) benchmark = Benchmark("accuracy", visualize=False)
benchmark.compile_results() benchmark.compile_results()
benchmark.save_results() benchmark.save_results()
with patch("sys.stdout", new=StringIO()) as fake_out: with patch(self.output, new=StringIO()) as fake_out:
benchmark.exreport() benchmark.exreport()
with open(os.path.join("test_files", "exreport.test")) as f: with open(os.path.join(self.test_files, "exreport.test")) as f:
expected_t = f.read() expected_t = f.read()
computed_t = fake_out.getvalue() computed_t = fake_out.getvalue()
computed_t = computed_t.split("\n") computed_t = computed_t.split("\n")
@@ -80,7 +71,7 @@ class BenchmarkTest(unittest.TestCase):
benchmark = Benchmark("accuracy", visualize=False) benchmark = Benchmark("accuracy", visualize=False)
benchmark.compile_results() benchmark.compile_results()
benchmark.save_results() benchmark.save_results()
with patch("sys.stdout", new=StringIO()): with patch(self.output, new=StringIO()):
benchmark.exreport() benchmark.exreport()
self.assertFalse(os.path.exists(Files.exreport_pdf)) self.assertFalse(os.path.exists(Files.exreport_pdf))
self.assertFalse(os.path.exists(Folders.report)) self.assertFalse(os.path.exists(Folders.report))
@@ -89,43 +80,34 @@ class BenchmarkTest(unittest.TestCase):
benchmark = Benchmark("unknown", visualize=False) benchmark = Benchmark("unknown", visualize=False)
benchmark.compile_results() benchmark.compile_results()
benchmark.save_results() benchmark.save_results()
with patch("sys.stdout", new=StringIO()) as fake_out: with patch(self.output, new=StringIO()) as fake_out:
benchmark.exreport() benchmark.exreport()
computed = fake_out.getvalue() self.check_output_file(fake_out, "exreport_error.test")
with open(os.path.join("test_files", "exreport_error.test")) as f:
expected = f.read()
self.assertEqual(computed, expected)
def test_tex_output(self): def test_tex_output(self):
benchmark = Benchmark("accuracy", visualize=False) benchmark = Benchmark("accuracy", visualize=False)
benchmark.compile_results() benchmark.compile_results()
benchmark.save_results() benchmark.save_results()
with patch("sys.stdout", new=StringIO()) as fake_out: with patch(self.output, new=StringIO()) as fake_out:
benchmark.report(tex_output=True) benchmark.report(tex_output=True)
with open(os.path.join("test_files", "exreport_report.test")) as f: with open(os.path.join(self.test_files, "exreport_report.test")) as f:
expected = f.read() expected = f.read()
self.assertEqual(fake_out.getvalue(), expected) self.assertEqual(fake_out.getvalue(), expected)
self.assertTrue(os.path.exists(benchmark.get_tex_file())) self.assertTrue(os.path.exists(benchmark.get_tex_file()))
with open(benchmark.get_tex_file()) as f: self.check_file_file(benchmark.get_tex_file(), "exreport_tex.test")
computed = f.read()
with open(os.path.join("test_files", "exreport_tex.test")) as f:
expected = f.read()
self.assertEqual(computed, expected)
def test_excel_output(self): def test_excel_output(self):
benchmark = Benchmark("accuracy", visualize=False) benchmark = Benchmark("accuracy", visualize=False)
benchmark.compile_results() benchmark.compile_results()
benchmark.save_results() benchmark.save_results()
with patch("sys.stdout", new=StringIO()): with patch(self.output, new=StringIO()):
benchmark.exreport() benchmark.exreport()
benchmark.excel() benchmark.excel()
file_name = benchmark.get_excel_file_name() file_name = benchmark.get_excel_file_name()
book = load_workbook(file_name) book = load_workbook(file_name)
for sheet_name in book.sheetnames: for sheet_name in book.sheetnames:
sheet = book[sheet_name] sheet = book[sheet_name]
ExcelTest.check_excel_sheet( self.check_excel_sheet(sheet, f"exreport_excel_{sheet_name}.test")
self, sheet, f"exreport_excel_{sheet_name}.test"
)
# ExcelTest.generate_excel_sheet( # ExcelTest.generate_excel_sheet(
# self, sheet, f"exreport_excel_{sheet_name}.test" # self, sheet, f"exreport_excel_{sheet_name}.test"
# ) # )

View File

@@ -1,13 +1,9 @@
import os import os
import unittest from .TestBase import TestBase
from ..Experiments import BestResults, Datasets from ..Experiments import BestResults, Datasets
class BestResultTest(unittest.TestCase): class BestResultTest(TestBase):
def __init__(self, *args, **kwargs):
os.chdir(os.path.dirname(os.path.abspath(__file__)))
super().__init__(*args, **kwargs)
def test_load(self): def test_load(self):
expected = { expected = {
"balance-scale": [ "balance-scale": [

View File

@@ -1,20 +1,16 @@
import os
import shutil import shutil
import unittest from .TestBase import TestBase
from ..Experiments import Randomized, Datasets from ..Experiments import Randomized, Datasets
class DatasetTest(unittest.TestCase): class DatasetTest(TestBase):
def __init__(self, *args, **kwargs): def setUp(self):
os.chdir(os.path.dirname(os.path.abspath(__file__)))
self.datasets_values = { self.datasets_values = {
"balance-scale": (625, 4, 3), "balance-scale": (625, 4, 3),
"balloons": (16, 4, 2), "balloons": (16, 4, 2),
"iris": (150, 4, 3), "iris": (150, 4, 3),
"wine": (178, 13, 3), "wine": (178, 13, 3),
} }
super().__init__(*args, **kwargs)
def tearDown(self) -> None: def tearDown(self) -> None:
self.set_env(".env.dist") self.set_env(".env.dist")

View File

@@ -1,17 +1,12 @@
import os import os
import csv
import unittest
from openpyxl import load_workbook from openpyxl import load_workbook
from xlsxwriter import Workbook from xlsxwriter import Workbook
from .TestBase import TestBase
from ..Results import Excel from ..Results import Excel
from ..Utils import Folders from ..Utils import Folders
class ExcelTest(unittest.TestCase): class ExcelTest(TestBase):
def __init__(self, *args, **kwargs):
os.chdir(os.path.dirname(os.path.abspath(__file__)))
super().__init__(*args, **kwargs)
def tearDown(self) -> None: def tearDown(self) -> None:
files = [ files = [
"results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.xlsx", "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.xlsx",
@@ -24,29 +19,6 @@ class ExcelTest(unittest.TestCase):
os.remove(file_name) os.remove(file_name)
return super().tearDown() return super().tearDown()
@staticmethod
def generate_excel_sheet(test, sheet, file_name):
with open(os.path.join("test_files", file_name), "w") as f:
for row in range(1, sheet.max_row + 1):
for col in range(1, sheet.max_column + 1):
value = sheet.cell(row=row, column=col).value
if value is not None:
print(f'{row};{col};"{value}"', file=f)
@staticmethod
def check_excel_sheet(test, sheet, file_name):
with open(os.path.join("test_files", file_name), "r") as f:
expected = csv.reader(f, delimiter=";")
for row, col, value in expected:
if value.isdigit():
value = int(value)
else:
try:
value = float(value)
except ValueError:
pass
test.assertEqual(sheet.cell(int(row), int(col)).value, value)
def test_report_excel_compared(self): def test_report_excel_compared(self):
file_name = "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json" file_name = "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json"
report = Excel(file_name, compare=True) report = Excel(file_name, compare=True)
@@ -54,7 +26,7 @@ class ExcelTest(unittest.TestCase):
file_output = report.get_file_name() file_output = report.get_file_name()
book = load_workbook(file_output) book = load_workbook(file_output)
sheet = book["STree"] sheet = book["STree"]
self.check_excel_sheet(self, sheet, "excel_compared.test") self.check_excel_sheet(sheet, "excel_compared.test")
def test_report_excel(self): def test_report_excel(self):
file_name = "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json" file_name = "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json"
@@ -63,7 +35,7 @@ class ExcelTest(unittest.TestCase):
file_output = report.get_file_name() file_output = report.get_file_name()
book = load_workbook(file_output) book = load_workbook(file_output)
sheet = book["STree"] sheet = book["STree"]
self.check_excel_sheet(self, sheet, "excel.test") self.check_excel_sheet(sheet, "excel.test")
def test_Excel_Add_sheet(self): def test_Excel_Add_sheet(self):
file_name = "results_accuracy_STree_iMac27_2021-10-27_09:40:40_0.json" file_name = "results_accuracy_STree_iMac27_2021-10-27_09:40:40_0.json"
@@ -79,6 +51,6 @@ class ExcelTest(unittest.TestCase):
book.close() book.close()
book = load_workbook(os.path.join(Folders.results, excel_file_name)) book = load_workbook(os.path.join(Folders.results, excel_file_name))
sheet = book["STree"] sheet = book["STree"]
self.check_excel_sheet(self, sheet, "excel_add_STree.test") self.check_excel_sheet(sheet, "excel_add_STree.test")
sheet = book["ODTE"] sheet = book["ODTE"]
self.check_excel_sheet(self, sheet, "excel_add_ODTE.test") self.check_excel_sheet(sheet, "excel_add_ODTE.test")

View File

@@ -1,14 +1,12 @@
import os import os
import json import json
import unittest from .TestBase import TestBase
from ..Experiments import Experiment, Datasets from ..Experiments import Experiment, Datasets
class ExperimentTest(unittest.TestCase): class ExperimentTest(TestBase):
def __init__(self, *args, **kwargs): def setUp(self):
os.chdir(os.path.dirname(os.path.abspath(__file__)))
self.exp = self.build_exp() self.exp = self.build_exp()
super().__init__(*args, **kwargs)
def build_exp(self, hyperparams=False, grid=False): def build_exp(self, hyperparams=False, grid=False):
params = { params = {

View File

@@ -1,14 +1,12 @@
import os import os
import json import json
import unittest from .TestBase import TestBase
from ..Experiments import GridSearch, Datasets from ..Experiments import GridSearch, Datasets
class GridSearchTest(unittest.TestCase): class GridSearchTest(TestBase):
def __init__(self, *args, **kwargs): def setUp(self):
os.chdir(os.path.dirname(os.path.abspath(__file__)))
self.grid = self.build_exp() self.grid = self.build_exp()
super().__init__(*args, **kwargs)
def tearDown(self) -> None: def tearDown(self) -> None:
grid = self.build_exp() grid = self.build_exp()

View File

@@ -1,4 +1,3 @@
import unittest
import warnings import warnings
from sklearn.exceptions import ConvergenceWarning from sklearn.exceptions import ConvergenceWarning
from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier
@@ -12,10 +11,11 @@ from sklearn.datasets import load_wine
from stree import Stree from stree import Stree
from wodt import Wodt from wodt import Wodt
from odte import Odte from odte import Odte
from .TestBase import TestBase
from ..Models import Models from ..Models import Models
class ModelTest(unittest.TestCase): class ModelTest(TestBase):
def test_Models(self): def test_Models(self):
test = { test = {
"STree": Stree, "STree": Stree,

View File

@@ -1,15 +1,11 @@
import os import os
import unittest
from io import StringIO from io import StringIO
from unittest.mock import patch from unittest.mock import patch
from .TestBase import TestBase
from ..Results import PairCheck from ..Results import PairCheck
class PairCheckTest(unittest.TestCase): class PairCheckTest(TestBase):
def __init__(self, *args, **kwargs):
os.chdir(os.path.dirname(os.path.abspath(__file__)))
super().__init__(*args, **kwargs)
def build_model( def build_model(
self, self,
score="accuracy", score="accuracy",
@@ -23,46 +19,35 @@ class PairCheckTest(unittest.TestCase):
def test_pair_check(self): def test_pair_check(self):
report = self.build_model(model1="ODTE", model2="STree") report = self.build_model(model1="ODTE", model2="STree")
report.compute() report.compute()
with patch("sys.stdout", new=StringIO()) as fake_out: with patch(self.output, new=StringIO()) as fake_out:
report.report() report.report()
computed = fake_out.getvalue() computed = fake_out.getvalue()
with open(os.path.join("test_files", "paircheck.test"), "r") as f: with open(os.path.join(self.test_files, "paircheck.test"), "r") as f:
expected = f.read() expected = f.read()
self.assertEqual(computed, expected) self.assertEqual(computed, expected)
def test_pair_check_win(self): def test_pair_check_win(self):
report = self.build_model(win=True) report = self.build_model(win=True)
report.compute() report.compute()
with patch("sys.stdout", new=StringIO()) as fake_out: with patch(self.output, new=StringIO()) as fake_out:
report.report() report.report()
computed = fake_out.getvalue() self.check_output_file(fake_out, "paircheck_win.test")
with open(os.path.join("test_files", "paircheck_win.test"), "r") as f:
expected = f.read()
self.assertEqual(computed, expected)
def test_pair_check_lose(self): def test_pair_check_lose(self):
report = self.build_model( report = self.build_model(
model1="RandomForest", model2="STree", lose=True model1="RandomForest", model2="STree", lose=True
) )
report.compute() report.compute()
with patch("sys.stdout", new=StringIO()) as fake_out: with patch(self.output, new=StringIO()) as fake_out:
report.report() report.report()
computed = fake_out.getvalue() self.check_output_file(fake_out, "paircheck_lose.test")
with open(os.path.join("test_files", "paircheck_lose.test"), "r") as f:
expected = f.read()
self.assertEqual(computed, expected)
def test_pair_check_win_lose(self): def test_pair_check_win_lose(self):
report = self.build_model(win=True, lose=True) report = self.build_model(win=True, lose=True)
report.compute() report.compute()
with patch("sys.stdout", new=StringIO()) as fake_out: with patch(self.output, new=StringIO()) as fake_out:
report.report() report.report()
computed = fake_out.getvalue() self.check_output_file(fake_out, "paircheck_win_lose.test")
with open(
os.path.join("test_files", "paircheck_win_lose.test"), "r"
) as f:
expected = f.read()
self.assertEqual(computed, expected)
def test_pair_check_store_result(self): def test_pair_check_store_result(self):
report = self.build_model(win=True, lose=True) report = self.build_model(win=True, lose=True)

View File

@@ -1,16 +1,11 @@
import os
import unittest
from io import StringIO from io import StringIO
from unittest.mock import patch from unittest.mock import patch
from .TestBase import TestBase
from ..Results import Report, BaseReport, ReportBest from ..Results import Report, BaseReport, ReportBest
from ..Utils import Symbols from ..Utils import Symbols
class ReportTest(unittest.TestCase): class ReportTest(TestBase):
def __init__(self, *args, **kwargs):
os.chdir(os.path.dirname(os.path.abspath(__file__)))
super().__init__(*args, **kwargs)
def test_BaseReport(self): def test_BaseReport(self):
with patch.multiple(BaseReport, __abstractmethods__=set()): with patch.multiple(BaseReport, __abstractmethods__=set()):
file_name = ( file_name = (
@@ -27,22 +22,18 @@ class ReportTest(unittest.TestCase):
file_name="results/results_accuracy_STree_iMac27_2021-09-30_11:" file_name="results/results_accuracy_STree_iMac27_2021-09-30_11:"
"42:07_0.json" "42:07_0.json"
) )
with patch("sys.stdout", new=StringIO()) as fake_out: with patch(self.output, new=StringIO()) as fake_out:
report.report() report.report()
with open("test_files/report.test", "r") as f: self.check_output_file(fake_out, "report.test")
expected = f.read()
self.assertEqual(fake_out.getvalue(), expected)
def test_report_without_folder(self): def test_report_without_folder(self):
report = Report( report = Report(
file_name="results_accuracy_STree_iMac27_2021-09-30_11:42:07_0" file_name="results_accuracy_STree_iMac27_2021-09-30_11:42:07_0"
".json" ".json"
) )
with patch("sys.stdout", new=StringIO()) as fake_out: with patch(self.output, new=StringIO()) as fake_out:
report.report() report.report()
with open("test_files/report.test", "r") as f: self.check_output_file(fake_out, "report.test")
expected = f.read()
self.assertEqual(fake_out.getvalue(), expected)
def test_report_compared(self): def test_report_compared(self):
report = Report( report = Report(
@@ -50,11 +41,9 @@ class ReportTest(unittest.TestCase):
".json", ".json",
compare=True, compare=True,
) )
with patch("sys.stdout", new=StringIO()) as fake_out: with patch(self.output, new=StringIO()) as fake_out:
report.report() report.report()
with open("test_files/report_compared.test", "r") as f: self.check_output_file(fake_out, "report_compared.test")
expected = f.read()
self.assertEqual(fake_out.getvalue(), expected)
def test_compute_status(self): def test_compute_status(self):
file_name = "results_accuracy_STree_iMac27_2021-10-27_09:40:40_0.json" file_name = "results_accuracy_STree_iMac27_2021-10-27_09:40:40_0.json"
@@ -62,7 +51,7 @@ class ReportTest(unittest.TestCase):
file_name=file_name, file_name=file_name,
compare=True, compare=True,
) )
with patch("sys.stdout", new=StringIO()): with patch(self.output, new=StringIO()):
report.report() report.report()
res = report._compute_status("balloons", 0.99) res = report._compute_status("balloons", 0.99)
self.assertEqual(res, Symbols.better_best) self.assertEqual(res, Symbols.better_best)
@@ -75,25 +64,18 @@ class ReportTest(unittest.TestCase):
def test_report_best(self): def test_report_best(self):
report = ReportBest("accuracy", "STree", best=True, grid=False) report = ReportBest("accuracy", "STree", best=True, grid=False)
with patch("sys.stdout", new=StringIO()) as fake_out: with patch(self.output, new=StringIO()) as fake_out:
report.report() report.report()
with open("test_files/report_best.test", "r") as f: self.check_output_file(fake_out, "report_best.test")
expected = f.read()
self.assertEqual(fake_out.getvalue(), expected)
def test_report_grid(self): def test_report_grid(self):
report = ReportBest("accuracy", "STree", best=False, grid=True) report = ReportBest("accuracy", "STree", best=False, grid=True)
with patch("sys.stdout", new=StringIO()) as fake_out: with patch(self.output, new=StringIO()) as fake_out:
report.report() report.report()
with open("test_files/report_grid.test", "r") as f: self.check_output_file(fake_out, "report_grid.test")
expected = f.read()
self.assertEqual(fake_out.getvalue(), expected)
def test_report_best_both(self): def test_report_best_both(self):
report = ReportBest("accuracy", "STree", best=True, grid=True) report = ReportBest("accuracy", "STree", best=True, grid=True)
with patch("sys.stdout", new=StringIO()) as fake_out: with patch(self.output, new=StringIO()) as fake_out:
report.report() report.report()
with open("test_files/report_best.test", "r") as f: self.check_output_file(fake_out, "report_best.test")
expected = f.read()
self.assertEqual(fake_out.getvalue(), expected)

View File

@@ -1,14 +1,10 @@
import os import os
import unittest from .TestBase import TestBase
from ..Results import SQL from ..Results import SQL
from ..Utils import Folders from ..Utils import Folders
class SQLTest(unittest.TestCase): class SQLTest(TestBase):
def __init__(self, *args, **kwargs):
os.chdir(os.path.dirname(os.path.abspath(__file__)))
super().__init__(*args, **kwargs)
def tearDown(self) -> None: def tearDown(self) -> None:
files = [ files = [
"results_accuracy_ODTE_Galgo_2022-04-20_10:52:20_0.sql", "results_accuracy_ODTE_Galgo_2022-04-20_10:52:20_0.sql",
@@ -26,9 +22,4 @@ class SQLTest(unittest.TestCase):
file_name = os.path.join( file_name = os.path.join(
Folders.results, file_name.replace(".json", ".sql") Folders.results, file_name.replace(".json", ".sql")
) )
self.check_file_file(file_name, "sql.test")
with open(file_name, "r") as file:
computed = file.read()
with open(os.path.join("test_files", "sql.test")) as f_exp:
expected = f_exp.read()
self.assertEqual(computed, expected)

View File

@@ -1,15 +1,10 @@
import os
import unittest
from io import StringIO from io import StringIO
from unittest.mock import patch from unittest.mock import patch
from .TestBase import TestBase
from ..Results import Summary from ..Results import Summary
class SummaryTest(unittest.TestCase): class SummaryTest(TestBase):
def __init__(self, *args, **kwargs):
os.chdir(os.path.dirname(os.path.abspath(__file__)))
super().__init__(*args, **kwargs)
def test_summary_without_model(self): def test_summary_without_model(self):
report = Summary() report = Summary()
report.acquire() report.acquire()
@@ -135,85 +130,57 @@ class SummaryTest(unittest.TestCase):
def test_summary_list_results_model(self): def test_summary_list_results_model(self):
report = Summary() report = Summary()
report.acquire() report.acquire()
with patch("sys.stdout", new=StringIO()) as fake_out: with patch(self.output, new=StringIO()) as fake_out:
report.list_results(model="STree") report.list_results(model="STree")
computed = fake_out.getvalue() self.check_output_file(fake_out, "summary_list_model.test")
with open(
os.path.join("test_files", "summary_list_model.test"), "r"
) as f:
expected = f.read()
self.assertEqual(computed, expected)
def test_summary_list_results_score(self): def test_summary_list_results_score(self):
report = Summary() report = Summary()
report.acquire() report.acquire()
with patch("sys.stdout", new=StringIO()) as fake_out: with patch(self.output, new=StringIO()) as fake_out:
report.list_results(score="accuracy") report.list_results(score="accuracy")
computed = fake_out.getvalue() self.check_output_file(fake_out, "summary_list_score.test")
with open(
os.path.join("test_files", "summary_list_score.test"), "r"
) as f:
expected = f.read()
self.assertEqual(computed, expected)
def test_summary_list_results_n(self): def test_summary_list_results_n(self):
report = Summary() report = Summary()
report.acquire() report.acquire()
with patch("sys.stdout", new=StringIO()) as fake_out: with patch(self.output, new=StringIO()) as fake_out:
report.list_results(score="accuracy", number=3) report.list_results(score="accuracy", number=3)
computed = fake_out.getvalue() self.check_output_file(fake_out, "summary_list_n.test")
with open(os.path.join("test_files", "summary_list_n.test"), "r") as f:
expected = f.read()
self.assertEqual(computed, expected)
def test_summary_list_hiden(self): def test_summary_list_hiden(self):
report = Summary(hidden=True) report = Summary(hidden=True)
report.acquire() report.acquire()
with patch("sys.stdout", new=StringIO()) as fake_out: with patch(self.output, new=StringIO()) as fake_out:
report.list_results(score="accuracy") report.list_results(score="accuracy")
computed = fake_out.getvalue() self.check_output_file(fake_out, "summary_list_hidden.test")
with open(
os.path.join("test_files", "summary_list_hidden.test"), "r"
) as f:
expected = f.read()
self.assertEqual(computed, expected)
def test_show_result_no_title(self): def test_show_result_no_title(self):
report = Summary() report = Summary()
report.acquire() report.acquire()
with patch("sys.stdout", new=StringIO()) as fake_out: with patch(self.output, new=StringIO()) as fake_out:
title = "" title = ""
best = report.best_result( best = report.best_result(
criterion="model", value="STree", score="accuracy" criterion="model", value="STree", score="accuracy"
) )
report.show_result(data=best, title=title) report.show_result(data=best, title=title)
computed = fake_out.getvalue() self.check_output_file(fake_out, "summary_show_results.test")
with open(
os.path.join("test_files", "summary_show_results.test"), "r"
) as f:
expected = f.read()
self.assertEqual(computed, expected)
def test_show_result_title(self): def test_show_result_title(self):
report = Summary() report = Summary()
report.acquire() report.acquire()
with patch("sys.stdout", new=StringIO()) as fake_out: with patch(self.output, new=StringIO()) as fake_out:
title = "**Title**" title = "**Title**"
best = report.best_result( best = report.best_result(
criterion="model", value="STree", score="accuracy" criterion="model", value="STree", score="accuracy"
) )
report.show_result(data=best, title=title) report.show_result(data=best, title=title)
computed = fake_out.getvalue() self.check_output_file(fake_out, "summary_show_results_title.test")
with open(
os.path.join("test_files", "summary_show_results_title.test"), "r"
) as f:
expected = f.read()
self.assertEqual(computed, expected)
def test_show_result_no_data(self): def test_show_result_no_data(self):
report = Summary() report = Summary()
report.acquire() report.acquire()
with patch("sys.stdout", new=StringIO()) as fake_out: with patch(self.output, new=StringIO()) as fake_out:
title = "**Test**" title = "**Test**"
report.show_result(data={}, title=title) report.show_result(data={}, title=title)
computed = fake_out.getvalue() computed = fake_out.getvalue()
@@ -245,11 +212,6 @@ class SummaryTest(unittest.TestCase):
def test_show_top(self): def test_show_top(self):
report = Summary() report = Summary()
report.acquire() report.acquire()
with patch("sys.stdout", new=StringIO()) as fake_out: with patch(self.output, new=StringIO()) as fake_out:
report.show_top() report.show_top()
computed = fake_out.getvalue() self.check_output_file(fake_out, "summary_show_top.test")
with open(
os.path.join("test_files", "summary_show_top.test"), "r"
) as f:
expected = f.read()
self.assertEqual(computed, expected)

View File

@@ -0,0 +1,44 @@
import os
import csv
import unittest
class TestBase(unittest.TestCase):
def __init__(self, *args, **kwargs):
os.chdir(os.path.dirname(os.path.abspath(__file__)))
self.test_files = "test_files"
self.output = "sys.stdout"
super().__init__(*args, **kwargs)
def generate_excel_sheet(self, sheet, file_name):
with open(os.path.join(self.test_files, file_name), "w") as f:
for row in range(1, sheet.max_row + 1):
for col in range(1, sheet.max_column + 1):
value = sheet.cell(row=row, column=col).value
if value is not None:
print(f'{row};{col};"{value}"', file=f)
def check_excel_sheet(self, sheet, file_name):
with open(os.path.join(self.test_files, file_name), "r") as f:
expected = csv.reader(f, delimiter=";")
for row, col, value in expected:
if value.isdigit():
value = int(value)
else:
try:
value = float(value)
except ValueError:
pass
self.assertEqual(sheet.cell(int(row), int(col)).value, value)
def check_output_file(self, output, file_name):
with open(os.path.join(self.test_files, file_name)) as f:
expected = f.read()
self.assertEqual(output.getvalue(), expected)
def check_file_file(self, computed_file, expected_file):
with open(computed_file) as f:
computed = f.read()
with open(os.path.join(self.test_files, expected_file)) as f:
expected = f.read()
self.assertEqual(computed, expected)

View File

@@ -1,11 +1,11 @@
import os import os
import sys import sys
import unittest
import argparse import argparse
from .TestBase import TestBase
from ..Utils import Folders, Files, Symbols, TextColor, EnvData, EnvDefault from ..Utils import Folders, Files, Symbols, TextColor, EnvData, EnvDefault
class UtilTest(unittest.TestCase): class UtilTest(TestBase):
def test_Folders(self): def test_Folders(self):
self.assertEqual("results", Folders.results) self.assertEqual("results", Folders.results)
self.assertEqual("hidden_results", Folders.hidden_results) self.assertEqual("hidden_results", Folders.hidden_results)