Begin print_strees_test

This commit is contained in:
2022-05-09 00:30:33 +02:00
parent b3bc2fbd2f
commit 534f32b625
4 changed files with 120 additions and 13 deletions

View File

@@ -58,6 +58,17 @@ class TestBase(unittest.TestCase):
expected = f.read() expected = f.read()
self.assertEqual(computed, expected) self.assertEqual(computed, expected)
def check_output_lines(self, stdout, file_name, lines_to_compare):
with open(os.path.join(self.test_files, f"{file_name}.test")) as f:
expected = f.read()
computed_data = stdout.getvalue().splitlines()
n_line = 0
# compare only report lines without date, time, duration...
for expected, computed in zip(expected.splitlines(), computed_data):
if n_line in lines_to_compare:
self.assertEqual(computed, expected, n_line)
n_line += 1
def prepare_scripts_env(self): def prepare_scripts_env(self):
self.scripts_folder = os.path.join( self.scripts_folder = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", "scripts" os.path.dirname(os.path.abspath(__file__)), "..", "scripts"

View File

@@ -19,6 +19,7 @@ from .scripts.Be_Grid_test import BeGridTest
from .scripts.Be_Best_test import BeBestTest from .scripts.Be_Best_test import BeBestTest
from .scripts.Be_Benchmark_test import BeBenchmarkTest from .scripts.Be_Benchmark_test import BeBenchmarkTest
from .scripts.Be_Main_test import BeMainTest from .scripts.Be_Main_test import BeMainTest
from .scripts.Be_Print_Strees_test import BePrintStrees
all = [ all = [
"UtilTest", "UtilTest",
@@ -42,4 +43,5 @@ all = [
"BeBestTest", "BeBestTest",
"BeBenchmarkTest", "BeBenchmarkTest",
"BeMainTest", "BeMainTest",
"BePrintStrees",
] ]

View File

@@ -1,7 +1,5 @@
import os
from io import StringIO from io import StringIO
from unittest.mock import patch from unittest.mock import patch
from ...Utils import Folders
from ...Results import Report from ...Results import Report
from ..TestBase import TestBase from ..TestBase import TestBase
@@ -16,17 +14,6 @@ class BeMainTest(TestBase):
self.remove_files(self.files, ".") self.remove_files(self.files, ".")
return super().tearDown() return super().tearDown()
def check_output_lines(self, stdout, file_name, lines_to_compare):
with open(os.path.join(self.test_files, f"{file_name}.test")) as f:
expected = f.read()
computed_data = stdout.getvalue().splitlines()
n_line = 0
# compare only report lines without date, time, duration...
for expected, computed in zip(expected.splitlines(), computed_data):
if n_line in lines_to_compare:
self.assertEqual(computed, expected, n_line)
n_line += 1
def test_be_benchmark_dataset(self): def test_be_benchmark_dataset(self):
stdout, _ = self.execute_script( stdout, _ = self.execute_script(
"be_main", "be_main",

View File

@@ -0,0 +1,107 @@
from io import StringIO
from unittest.mock import patch
from ...Results import Report
from ..TestBase import TestBase
class BePrintStrees(TestBase):
def setUp(self):
self.prepare_scripts_env()
self.score = "accuracy"
self.files = []
def tearDown(self) -> None:
self.remove_files(self.files, ".")
return super().tearDown()
# def test_be_benchmark_dataset(self):
# stdout, _ = self.execute_script(
# "be_main",
# ["-m", "STree", "-d", "balloons", "--title", "test"],
# )
# self.check_output_lines(
# stdout=stdout,
# file_name="be_main_dataset",
# lines_to_compare=[0, 2, 3, 5, 6, 7, 8, 9, 11, 12, 13],
# )
# def test_be_benchmark_complete(self):
# stdout, _ = self.execute_script(
# "be_main",
# ["-s", self.score, "-m", "STree", "--title", "test", "-r", "1"],
# )
# # keep the report name to delete it after
# report_name = stdout.getvalue().splitlines()[-1].split("in ")[1]
# self.files.append(report_name)
# self.check_output_lines(
# stdout, "be_main_complete", [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14]
# )
# def test_be_benchmark_no_report(self):
# stdout, _ = self.execute_script(
# "be_main",
# ["-s", self.score, "-m", "STree", "--title", "test"],
# )
# # keep the report name to delete it after
# report_name = stdout.getvalue().splitlines()[-1].split("in ")[1]
# self.files.append(report_name)
# report = Report(file_name=report_name)
# with patch(self.output, new=StringIO()) as stdout:
# report.report()
# self.check_output_lines(
# stdout,
# "be_main_complete",
# [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14],
# )
# def test_be_benchmark_best_params(self):
# stdout, _ = self.execute_script(
# "be_main",
# [
# "-s",
# self.score,
# "-m",
# "STree",
# "--title",
# "test",
# "-f",
# "1",
# "-r",
# "1",
# ],
# )
# # keep the report name to delete it after
# report_name = stdout.getvalue().splitlines()[-1].split("in ")[1]
# self.files.append(report_name)
# self.check_output_lines(
# stdout, "be_main_best", [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14]
# )
# def test_be_benchmark_grid_params(self):
# stdout, _ = self.execute_script(
# "be_main",
# [
# "-s",
# self.score,
# "-m",
# "STree",
# "--title",
# "test",
# "-g",
# "1",
# "-r",
# "1",
# ],
# )
# # keep the report name to delete it after
# report_name = stdout.getvalue().splitlines()[-1].split("in ")[1]
# self.files.append(report_name)
# self.check_output_lines(
# stdout, "be_main_grid", [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14]
# )
# def test_be_benchmark_no_data(self):
# stdout, _ = self.execute_script(
# "be_main", ["-m", "STree", "-d", "unknown", "--title", "test"]
# )
# self.assertEqual(stdout.getvalue(), "Unknown dataset: unknown\n")