Fix error in tests with STree fixed version

This commit is contained in:
2022-10-24 19:05:13 +02:00
parent 12024df4d8
commit 34b25756ea
5 changed files with 31 additions and 5 deletions

View File

@@ -77,7 +77,9 @@ class GridSearchTest(TestBase):
"v. 1.2.4, Computed on Test on 2022-02-22 at 12:00:00 took 1s",
],
}
self.assertSequenceEqual(data, expected)
for key, value in expected.items():
self.assertEqual(data[key][0], value[0])
self.assertDictEqual(data[key][1], value[1])
def test_duration_message(self):
expected = ["47.234s", "5.421m", "1.177h"]

View File

@@ -75,7 +75,17 @@ class ReportTest(TestBase):
report = ReportBest("accuracy", "STree", best=False, grid=True)
with patch(self.output, new=StringIO()) as stdout:
report.report()
self.check_output_file(stdout, "report_grid")
file_name = "report_grid.test"
with open(os.path.join(self.test_files, file_name)) as f:
expected = f.read().splitlines()
output_text = stdout.getvalue().splitlines()
# Compare replacing STree version
for line, index in zip(expected, range(len(expected))):
if "1.2.4" in line:
# replace STree version
line = self.replace_STree_version(line, output_text, index)
self.assertEqual(line, output_text[index])
def test_report_best_both(self):
report = ReportBest("accuracy", "STree", best=True, grid=True)

View File

@@ -50,6 +50,11 @@ class TestBase(unittest.TestCase):
expected = f.read()
self.assertEqual(output.getvalue(), expected)
@staticmethod
def replace_STree_version(line, output, index):
idx = line.find("1.2.4")
return line.replace("1.2.4", output[index][idx : idx + 5])
def check_file_file(self, computed_file, expected_file):
with open(computed_file) as f:
computed = f.read()

View File

@@ -6,7 +6,7 @@
"kernel": "liblinear",
"multiclass_strategy": "ovr"
},
"v. 1.2.4, Computed on Test on 2022-02-22 at 12:00:00 took 1s"
"v. 1.3.0, Computed on Test on 2022-02-22 at 12:00:00 took 1s"
],
"balloons": [
0.625,
@@ -15,6 +15,6 @@
"kernel": "linear",
"multiclass_strategy": "ovr"
},
"v. 1.2.4, Computed on Test on 2022-02-22 at 12:00:00 took 1s"
"v. 1.3.0, Computed on Test on 2022-02-22 at 12:00:00 took 1s"
]
}

View File

@@ -55,7 +55,16 @@ class BeReportTest(TestBase):
"be_report", ["-s", "accuracy", "-m", "STree", "-g", "1"]
)
self.assertEqual(stderr.getvalue(), "")
self.check_output_file(stdout, "report_grid")
file_name = "report_grid.test"
with open(os.path.join(self.test_files, file_name)) as f:
expected = f.read().splitlines()
output_text = stdout.getvalue().splitlines()
# Compare replacing STree version
for line, index in zip(expected, range(len(expected))):
if "1.2.4" in line:
# replace STree version
line = self.replace_STree_version(line, output_text, index)
self.assertEqual(line, output_text[index])
def test_be_report_best_both(self):
stdout, stderr = self.execute_script(