From 34b25756ea312be7bc5ccfcc471b70f05d31fa0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana?= Date: Mon, 24 Oct 2022 19:05:13 +0200 Subject: [PATCH] Fix error in tests with STree fixed version --- benchmark/tests/GridSearch_test.py | 4 +++- benchmark/tests/Report_test.py | 12 +++++++++++- benchmark/tests/TestBase.py | 5 +++++ .../tests/results/grid_output_accuracy_STree.json | 4 ++-- benchmark/tests/scripts/Be_Report_test.py | 11 ++++++++++- 5 files changed, 31 insertions(+), 5 deletions(-) diff --git a/benchmark/tests/GridSearch_test.py b/benchmark/tests/GridSearch_test.py index 4cfb0f6..d003d6d 100644 --- a/benchmark/tests/GridSearch_test.py +++ b/benchmark/tests/GridSearch_test.py @@ -77,7 +77,9 @@ class GridSearchTest(TestBase): "v. 1.2.4, Computed on Test on 2022-02-22 at 12:00:00 took 1s", ], } - self.assertSequenceEqual(data, expected) + for key, value in expected.items(): + self.assertEqual(data[key][0], value[0]) + self.assertDictEqual(data[key][1], value[1]) def test_duration_message(self): expected = ["47.234s", "5.421m", "1.177h"] diff --git a/benchmark/tests/Report_test.py b/benchmark/tests/Report_test.py index 24961bd..e2ae041 100644 --- a/benchmark/tests/Report_test.py +++ b/benchmark/tests/Report_test.py @@ -75,7 +75,17 @@ class ReportTest(TestBase): report = ReportBest("accuracy", "STree", best=False, grid=True) with patch(self.output, new=StringIO()) as stdout: report.report() - self.check_output_file(stdout, "report_grid") + file_name = "report_grid.test" + with open(os.path.join(self.test_files, file_name)) as f: + expected = f.read().splitlines() + output_text = stdout.getvalue().splitlines() + # Compare replacing STree version + for line, index in zip(expected, range(len(expected))): + if "1.2.4" in line: + # replace STree version + line = self.replace_STree_version(line, output_text, index) + + self.assertEqual(line, output_text[index]) def test_report_best_both(self): report = ReportBest("accuracy", "STree", best=True, grid=True) diff --git a/benchmark/tests/TestBase.py b/benchmark/tests/TestBase.py index 562700c..af33d8a 100644 --- a/benchmark/tests/TestBase.py +++ b/benchmark/tests/TestBase.py @@ -50,6 +50,11 @@ class TestBase(unittest.TestCase): expected = f.read() self.assertEqual(output.getvalue(), expected) + @staticmethod + def replace_STree_version(line, output, index): + idx = line.find("1.2.4") + return line.replace("1.2.4", output[index][idx : idx + 5]) + def check_file_file(self, computed_file, expected_file): with open(computed_file) as f: computed = f.read() diff --git a/benchmark/tests/results/grid_output_accuracy_STree.json b/benchmark/tests/results/grid_output_accuracy_STree.json index 7f197d6..731e0b7 100644 --- a/benchmark/tests/results/grid_output_accuracy_STree.json +++ b/benchmark/tests/results/grid_output_accuracy_STree.json @@ -6,7 +6,7 @@ "kernel": "liblinear", "multiclass_strategy": "ovr" }, - "v. 1.2.4, Computed on Test on 2022-02-22 at 12:00:00 took 1s" + "v. 1.3.0, Computed on Test on 2022-02-22 at 12:00:00 took 1s" ], "balloons": [ 0.625, @@ -15,6 +15,6 @@ "kernel": "linear", "multiclass_strategy": "ovr" }, - "v. 1.2.4, Computed on Test on 2022-02-22 at 12:00:00 took 1s" + "v. 1.3.0, Computed on Test on 2022-02-22 at 12:00:00 took 1s" ] } \ No newline at end of file diff --git a/benchmark/tests/scripts/Be_Report_test.py b/benchmark/tests/scripts/Be_Report_test.py index 5fe2680..14a51f8 100644 --- a/benchmark/tests/scripts/Be_Report_test.py +++ b/benchmark/tests/scripts/Be_Report_test.py @@ -55,7 +55,16 @@ class BeReportTest(TestBase): "be_report", ["-s", "accuracy", "-m", "STree", "-g", "1"] ) self.assertEqual(stderr.getvalue(), "") - self.check_output_file(stdout, "report_grid") + file_name = "report_grid.test" + with open(os.path.join(self.test_files, file_name)) as f: + expected = f.read().splitlines() + output_text = stdout.getvalue().splitlines() + # Compare replacing STree version + for line, index in zip(expected, range(len(expected))): + if "1.2.4" in line: + # replace STree version + line = self.replace_STree_version(line, output_text, index) + self.assertEqual(line, output_text[index]) def test_be_report_best_both(self): stdout, stderr = self.execute_script(