Complete be_main tests

This commit is contained in:
2022-05-09 00:23:18 +02:00
parent 09b2ede836
commit b3bc2fbd2f
3 changed files with 104 additions and 34 deletions

View File

@@ -16,64 +16,102 @@ class BeMainTest(TestBase):
self.remove_files(self.files, ".")
return super().tearDown()
def check_output_lines(self, stdout, file_name, lines_to_compare):
with open(os.path.join(self.test_files, f"{file_name}.test")) as f:
expected = f.read()
computed_data = stdout.getvalue().splitlines()
n_line = 0
# compare only report lines without date, time, duration...
for expected, computed in zip(expected.splitlines(), computed_data):
if n_line in lines_to_compare:
self.assertEqual(computed, expected, n_line)
n_line += 1
def test_be_benchmark_dataset(self):
stdout, _ = self.execute_script(
"be_main",
["-m", "STree", "-d", "balloons", "--title", "test"],
)
with open(os.path.join(self.test_files, "be_main_dataset.test")) as f:
expected = f.read()
n_line = 0
# compare only report lines without date, time, duration...
lines_to_compare = [0, 2, 3, 5, 6, 7, 8, 9, 11, 12, 13]
for expected, computed in zip(
expected.splitlines(), stdout.getvalue().splitlines()
):
if n_line in lines_to_compare:
self.assertEqual(computed, expected, n_line)
n_line += 1
self.check_output_lines(
stdout=stdout,
file_name="be_main_dataset",
lines_to_compare=[0, 2, 3, 5, 6, 7, 8, 9, 11, 12, 13],
)
def test_be_benchmark_complete(self):
stdout, _ = self.execute_script(
"be_main",
["-s", self.score, "-m", "STree", "--title", "test", "-r", "1"],
)
with open(os.path.join(self.test_files, "be_main_complete.test")) as f:
expected = f.read()
n_line = 0
# keep the report name to delete it after
self.files.append(stdout.getvalue().splitlines()[-1].split("in ")[1])
# compare only report lines without date, time, duration...
lines_to_compare = [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14]
for expected, computed in zip(
expected.splitlines(), stdout.getvalue().splitlines()
):
if n_line in lines_to_compare:
self.assertEqual(computed, expected, n_line)
n_line += 1
report_name = stdout.getvalue().splitlines()[-1].split("in ")[1]
self.files.append(report_name)
self.check_output_lines(
stdout, "be_main_complete", [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14]
)
def test_be_benchmark_no_report(self):
stdout, _ = self.execute_script(
"be_main",
["-s", self.score, "-m", "STree", "--title", "test"],
)
with open(os.path.join(self.test_files, "be_main_complete.test")) as f:
expected = f.read()
# keep the report name to delete it after
report_name = stdout.getvalue().splitlines()[-1].split("in ")[1]
self.files.append(report_name)
report = Report(file_name=report_name)
with patch(self.output, new=StringIO()) as stdout:
report.report()
# compare only report lines without date, time, duration...
lines_to_compare = [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14]
n_line = 0
for expected, computed in zip(
expected.splitlines(), stdout.getvalue().splitlines()
):
if n_line in lines_to_compare:
self.assertEqual(computed, expected, n_line)
n_line += 1
self.check_output_lines(
stdout,
"be_main_complete",
[0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14],
)
def test_be_benchmark_best_params(self):
stdout, _ = self.execute_script(
"be_main",
[
"-s",
self.score,
"-m",
"STree",
"--title",
"test",
"-f",
"1",
"-r",
"1",
],
)
# keep the report name to delete it after
report_name = stdout.getvalue().splitlines()[-1].split("in ")[1]
self.files.append(report_name)
self.check_output_lines(
stdout, "be_main_best", [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14]
)
def test_be_benchmark_grid_params(self):
stdout, _ = self.execute_script(
"be_main",
[
"-s",
self.score,
"-m",
"STree",
"--title",
"test",
"-g",
"1",
"-r",
"1",
],
)
# keep the report name to delete it after
report_name = stdout.getvalue().splitlines()[-1].split("in ")[1]
self.files.append(report_name)
self.check_output_lines(
stdout, "be_main_grid", [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14]
)
def test_be_benchmark_no_data(self):
stdout, _ = self.execute_script(

View File

@@ -0,0 +1,16 @@
***********************************************************************************************************************
* Report STree ver. 1.2.4 with 5 Folds cross validation and 10 random seeds. 2022-05-09 00:15:25 *
* test *
* Random seeds: [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] Stratified: False *
* Execution took 0.80 seconds, 0.00 hours, on iMac27 *
* Score is accuracy *
***********************************************************************************************************************
Dataset Samp Feat. Cls Nodes Leaves Depth Score Time Hyperparameters
============================== ===== ===== === ======= ======= ======= =============== ================ ===============
balance-scale 625 4 3 23.32 12.16 6.44 0.840160±0.0304 0.013745±0.0019 {'splitter': 'best', 'max_features': 'auto'}
balloons 16 4 2 3.00 2.00 2.00 0.860000±0.2850 0.000388±0.0000 {'C': 7, 'gamma': 0.1, 'kernel': 'rbf', 'max_iter': 10000.0, 'multiclass_strategy': 'ovr'}
***********************************************************************************************************************
* Accuracy compared to stree_default (liblinear-ovr) .: 0.0422 *
***********************************************************************************************************************
Results in results/results_accuracy_STree_iMac27_2022-05-09_00:15:25_0.json

View File

@@ -0,0 +1,16 @@
***********************************************************************************************************************
* Report STree ver. 1.2.4 with 5 Folds cross validation and 10 random seeds. 2022-05-09 00:21:06 *
* test *
* Random seeds: [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] Stratified: False *
* Execution took 0.89 seconds, 0.00 hours, on iMac27 *
* Score is accuracy *
***********************************************************************************************************************
Dataset Samp Feat. Cls Nodes Leaves Depth Score Time Hyperparameters
============================== ===== ===== === ======= ======= ======= =============== ================ ===============
balance-scale 625 4 3 26.12 13.56 7.94 0.910720±0.0249 0.015852±0.0027 {'C': 1.0, 'kernel': 'liblinear', 'multiclass_strategy': 'ovr'}
balloons 16 4 2 4.64 2.82 2.66 0.663333±0.3009 0.000640±0.0001 {'C': 1.0, 'kernel': 'linear', 'multiclass_strategy': 'ovr'}
***********************************************************************************************************************
* Accuracy compared to stree_default (liblinear-ovr) .: 0.0391 *
***********************************************************************************************************************
Results in results/results_accuracy_STree_iMac27_2022-05-09_00:21:06_0.json