Begin print_strees_test

This commit is contained in:
2022-05-09 00:30:33 +02:00
parent b3bc2fbd2f
commit 534f32b625
4 changed files with 120 additions and 13 deletions

View File

@@ -58,6 +58,17 @@ class TestBase(unittest.TestCase):
expected = f.read()
self.assertEqual(computed, expected)
def check_output_lines(self, stdout, file_name, lines_to_compare):
with open(os.path.join(self.test_files, f"{file_name}.test")) as f:
expected = f.read()
computed_data = stdout.getvalue().splitlines()
n_line = 0
# compare only report lines without date, time, duration...
for expected, computed in zip(expected.splitlines(), computed_data):
if n_line in lines_to_compare:
self.assertEqual(computed, expected, n_line)
n_line += 1
def prepare_scripts_env(self):
self.scripts_folder = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", "scripts"

View File

@@ -19,6 +19,7 @@ from .scripts.Be_Grid_test import BeGridTest
from .scripts.Be_Best_test import BeBestTest
from .scripts.Be_Benchmark_test import BeBenchmarkTest
from .scripts.Be_Main_test import BeMainTest
from .scripts.Be_Print_Strees_test import BePrintStrees
all = [
"UtilTest",
@@ -42,4 +43,5 @@ all = [
"BeBestTest",
"BeBenchmarkTest",
"BeMainTest",
"BePrintStrees",
]

View File

@@ -1,7 +1,5 @@
import os
from io import StringIO
from unittest.mock import patch
from ...Utils import Folders
from ...Results import Report
from ..TestBase import TestBase
@@ -16,17 +14,6 @@ class BeMainTest(TestBase):
self.remove_files(self.files, ".")
return super().tearDown()
def check_output_lines(self, stdout, file_name, lines_to_compare):
with open(os.path.join(self.test_files, f"{file_name}.test")) as f:
expected = f.read()
computed_data = stdout.getvalue().splitlines()
n_line = 0
# compare only report lines without date, time, duration...
for expected, computed in zip(expected.splitlines(), computed_data):
if n_line in lines_to_compare:
self.assertEqual(computed, expected, n_line)
n_line += 1
def test_be_benchmark_dataset(self):
stdout, _ = self.execute_script(
"be_main",

View File

@@ -0,0 +1,107 @@
from io import StringIO
from unittest.mock import patch
from ...Results import Report
from ..TestBase import TestBase
class BePrintStrees(TestBase):
def setUp(self):
self.prepare_scripts_env()
self.score = "accuracy"
self.files = []
def tearDown(self) -> None:
self.remove_files(self.files, ".")
return super().tearDown()
# def test_be_benchmark_dataset(self):
# stdout, _ = self.execute_script(
# "be_main",
# ["-m", "STree", "-d", "balloons", "--title", "test"],
# )
# self.check_output_lines(
# stdout=stdout,
# file_name="be_main_dataset",
# lines_to_compare=[0, 2, 3, 5, 6, 7, 8, 9, 11, 12, 13],
# )
# def test_be_benchmark_complete(self):
# stdout, _ = self.execute_script(
# "be_main",
# ["-s", self.score, "-m", "STree", "--title", "test", "-r", "1"],
# )
# # keep the report name to delete it after
# report_name = stdout.getvalue().splitlines()[-1].split("in ")[1]
# self.files.append(report_name)
# self.check_output_lines(
# stdout, "be_main_complete", [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14]
# )
# def test_be_benchmark_no_report(self):
# stdout, _ = self.execute_script(
# "be_main",
# ["-s", self.score, "-m", "STree", "--title", "test"],
# )
# # keep the report name to delete it after
# report_name = stdout.getvalue().splitlines()[-1].split("in ")[1]
# self.files.append(report_name)
# report = Report(file_name=report_name)
# with patch(self.output, new=StringIO()) as stdout:
# report.report()
# self.check_output_lines(
# stdout,
# "be_main_complete",
# [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14],
# )
# def test_be_benchmark_best_params(self):
# stdout, _ = self.execute_script(
# "be_main",
# [
# "-s",
# self.score,
# "-m",
# "STree",
# "--title",
# "test",
# "-f",
# "1",
# "-r",
# "1",
# ],
# )
# # keep the report name to delete it after
# report_name = stdout.getvalue().splitlines()[-1].split("in ")[1]
# self.files.append(report_name)
# self.check_output_lines(
# stdout, "be_main_best", [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14]
# )
# def test_be_benchmark_grid_params(self):
# stdout, _ = self.execute_script(
# "be_main",
# [
# "-s",
# self.score,
# "-m",
# "STree",
# "--title",
# "test",
# "-g",
# "1",
# "-r",
# "1",
# ],
# )
# # keep the report name to delete it after
# report_name = stdout.getvalue().splitlines()[-1].split("in ")[1]
# self.files.append(report_name)
# self.check_output_lines(
# stdout, "be_main_grid", [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14]
# )
# def test_be_benchmark_no_data(self):
# stdout, _ = self.execute_script(
# "be_main", ["-m", "STree", "-d", "unknown", "--title", "test"]
# )
# self.assertEqual(stdout.getvalue(), "Unknown dataset: unknown\n")