mirror of
https://github.com/Doctorado-ML/benchmark.git
synced 2025-08-15 23:45:54 +00:00
Fix error in tests with STree fixed version
This commit is contained in:
@@ -77,7 +77,9 @@ class GridSearchTest(TestBase):
|
||||
"v. 1.2.4, Computed on Test on 2022-02-22 at 12:00:00 took 1s",
|
||||
],
|
||||
}
|
||||
self.assertSequenceEqual(data, expected)
|
||||
for key, value in expected.items():
|
||||
self.assertEqual(data[key][0], value[0])
|
||||
self.assertDictEqual(data[key][1], value[1])
|
||||
|
||||
def test_duration_message(self):
|
||||
expected = ["47.234s", "5.421m", "1.177h"]
|
||||
|
@@ -75,7 +75,17 @@ class ReportTest(TestBase):
|
||||
report = ReportBest("accuracy", "STree", best=False, grid=True)
|
||||
with patch(self.output, new=StringIO()) as stdout:
|
||||
report.report()
|
||||
self.check_output_file(stdout, "report_grid")
|
||||
file_name = "report_grid.test"
|
||||
with open(os.path.join(self.test_files, file_name)) as f:
|
||||
expected = f.read().splitlines()
|
||||
output_text = stdout.getvalue().splitlines()
|
||||
# Compare replacing STree version
|
||||
for line, index in zip(expected, range(len(expected))):
|
||||
if "1.2.4" in line:
|
||||
# replace STree version
|
||||
line = self.replace_STree_version(line, output_text, index)
|
||||
|
||||
self.assertEqual(line, output_text[index])
|
||||
|
||||
def test_report_best_both(self):
|
||||
report = ReportBest("accuracy", "STree", best=True, grid=True)
|
||||
|
@@ -50,6 +50,11 @@ class TestBase(unittest.TestCase):
|
||||
expected = f.read()
|
||||
self.assertEqual(output.getvalue(), expected)
|
||||
|
||||
@staticmethod
|
||||
def replace_STree_version(line, output, index):
|
||||
idx = line.find("1.2.4")
|
||||
return line.replace("1.2.4", output[index][idx : idx + 5])
|
||||
|
||||
def check_file_file(self, computed_file, expected_file):
|
||||
with open(computed_file) as f:
|
||||
computed = f.read()
|
||||
|
@@ -6,7 +6,7 @@
|
||||
"kernel": "liblinear",
|
||||
"multiclass_strategy": "ovr"
|
||||
},
|
||||
"v. 1.2.4, Computed on Test on 2022-02-22 at 12:00:00 took 1s"
|
||||
"v. 1.3.0, Computed on Test on 2022-02-22 at 12:00:00 took 1s"
|
||||
],
|
||||
"balloons": [
|
||||
0.625,
|
||||
@@ -15,6 +15,6 @@
|
||||
"kernel": "linear",
|
||||
"multiclass_strategy": "ovr"
|
||||
},
|
||||
"v. 1.2.4, Computed on Test on 2022-02-22 at 12:00:00 took 1s"
|
||||
"v. 1.3.0, Computed on Test on 2022-02-22 at 12:00:00 took 1s"
|
||||
]
|
||||
}
|
@@ -55,7 +55,16 @@ class BeReportTest(TestBase):
|
||||
"be_report", ["-s", "accuracy", "-m", "STree", "-g", "1"]
|
||||
)
|
||||
self.assertEqual(stderr.getvalue(), "")
|
||||
self.check_output_file(stdout, "report_grid")
|
||||
file_name = "report_grid.test"
|
||||
with open(os.path.join(self.test_files, file_name)) as f:
|
||||
expected = f.read().splitlines()
|
||||
output_text = stdout.getvalue().splitlines()
|
||||
# Compare replacing STree version
|
||||
for line, index in zip(expected, range(len(expected))):
|
||||
if "1.2.4" in line:
|
||||
# replace STree version
|
||||
line = self.replace_STree_version(line, output_text, index)
|
||||
self.assertEqual(line, output_text[index])
|
||||
|
||||
def test_be_report_best_both(self):
|
||||
stdout, stderr = self.execute_script(
|
||||
|
Reference in New Issue
Block a user