Complete be_main tests

This commit is contained in:
2022-05-09 00:23:18 +02:00
parent 09b2ede836
commit b3bc2fbd2f
3 changed files with 104 additions and 34 deletions

View File

@@ -16,64 +16,102 @@ class BeMainTest(TestBase):
self.remove_files(self.files, ".") self.remove_files(self.files, ".")
return super().tearDown() return super().tearDown()
def check_output_lines(self, stdout, file_name, lines_to_compare):
with open(os.path.join(self.test_files, f"{file_name}.test")) as f:
expected = f.read()
computed_data = stdout.getvalue().splitlines()
n_line = 0
# compare only report lines without date, time, duration...
for expected, computed in zip(expected.splitlines(), computed_data):
if n_line in lines_to_compare:
self.assertEqual(computed, expected, n_line)
n_line += 1
def test_be_benchmark_dataset(self): def test_be_benchmark_dataset(self):
stdout, _ = self.execute_script( stdout, _ = self.execute_script(
"be_main", "be_main",
["-m", "STree", "-d", "balloons", "--title", "test"], ["-m", "STree", "-d", "balloons", "--title", "test"],
) )
with open(os.path.join(self.test_files, "be_main_dataset.test")) as f: self.check_output_lines(
expected = f.read() stdout=stdout,
n_line = 0 file_name="be_main_dataset",
# compare only report lines without date, time, duration... lines_to_compare=[0, 2, 3, 5, 6, 7, 8, 9, 11, 12, 13],
lines_to_compare = [0, 2, 3, 5, 6, 7, 8, 9, 11, 12, 13] )
for expected, computed in zip(
expected.splitlines(), stdout.getvalue().splitlines()
):
if n_line in lines_to_compare:
self.assertEqual(computed, expected, n_line)
n_line += 1
def test_be_benchmark_complete(self): def test_be_benchmark_complete(self):
stdout, _ = self.execute_script( stdout, _ = self.execute_script(
"be_main", "be_main",
["-s", self.score, "-m", "STree", "--title", "test", "-r", "1"], ["-s", self.score, "-m", "STree", "--title", "test", "-r", "1"],
) )
with open(os.path.join(self.test_files, "be_main_complete.test")) as f:
expected = f.read()
n_line = 0
# keep the report name to delete it after # keep the report name to delete it after
self.files.append(stdout.getvalue().splitlines()[-1].split("in ")[1]) report_name = stdout.getvalue().splitlines()[-1].split("in ")[1]
# compare only report lines without date, time, duration... self.files.append(report_name)
lines_to_compare = [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14] self.check_output_lines(
for expected, computed in zip( stdout, "be_main_complete", [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14]
expected.splitlines(), stdout.getvalue().splitlines() )
):
if n_line in lines_to_compare:
self.assertEqual(computed, expected, n_line)
n_line += 1
def test_be_benchmark_no_report(self): def test_be_benchmark_no_report(self):
stdout, _ = self.execute_script( stdout, _ = self.execute_script(
"be_main", "be_main",
["-s", self.score, "-m", "STree", "--title", "test"], ["-s", self.score, "-m", "STree", "--title", "test"],
) )
with open(os.path.join(self.test_files, "be_main_complete.test")) as f:
expected = f.read()
# keep the report name to delete it after # keep the report name to delete it after
report_name = stdout.getvalue().splitlines()[-1].split("in ")[1] report_name = stdout.getvalue().splitlines()[-1].split("in ")[1]
self.files.append(report_name) self.files.append(report_name)
report = Report(file_name=report_name) report = Report(file_name=report_name)
with patch(self.output, new=StringIO()) as stdout: with patch(self.output, new=StringIO()) as stdout:
report.report() report.report()
# compare only report lines without date, time, duration... self.check_output_lines(
lines_to_compare = [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14] stdout,
n_line = 0 "be_main_complete",
for expected, computed in zip( [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14],
expected.splitlines(), stdout.getvalue().splitlines() )
):
if n_line in lines_to_compare: def test_be_benchmark_best_params(self):
self.assertEqual(computed, expected, n_line) stdout, _ = self.execute_script(
n_line += 1 "be_main",
[
"-s",
self.score,
"-m",
"STree",
"--title",
"test",
"-f",
"1",
"-r",
"1",
],
)
# keep the report name to delete it after
report_name = stdout.getvalue().splitlines()[-1].split("in ")[1]
self.files.append(report_name)
self.check_output_lines(
stdout, "be_main_best", [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14]
)
def test_be_benchmark_grid_params(self):
stdout, _ = self.execute_script(
"be_main",
[
"-s",
self.score,
"-m",
"STree",
"--title",
"test",
"-g",
"1",
"-r",
"1",
],
)
# keep the report name to delete it after
report_name = stdout.getvalue().splitlines()[-1].split("in ")[1]
self.files.append(report_name)
self.check_output_lines(
stdout, "be_main_grid", [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14]
)
def test_be_benchmark_no_data(self): def test_be_benchmark_no_data(self):
stdout, _ = self.execute_script( stdout, _ = self.execute_script(

View File

@@ -0,0 +1,16 @@
***********************************************************************************************************************
* Report STree ver. 1.2.4 with 5 Folds cross validation and 10 random seeds. 2022-05-09 00:15:25 *
* test *
* Random seeds: [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] Stratified: False *
* Execution took 0.80 seconds, 0.00 hours, on iMac27 *
* Score is accuracy *
***********************************************************************************************************************
Dataset Samp Feat. Cls Nodes Leaves Depth Score Time Hyperparameters
============================== ===== ===== === ======= ======= ======= =============== ================ ===============
balance-scale 625 4 3 23.32 12.16 6.44 0.840160±0.0304 0.013745±0.0019 {'splitter': 'best', 'max_features': 'auto'}
balloons 16 4 2 3.00 2.00 2.00 0.860000±0.2850 0.000388±0.0000 {'C': 7, 'gamma': 0.1, 'kernel': 'rbf', 'max_iter': 10000.0, 'multiclass_strategy': 'ovr'}
***********************************************************************************************************************
* Accuracy compared to stree_default (liblinear-ovr) .: 0.0422 *
***********************************************************************************************************************
Results in results/results_accuracy_STree_iMac27_2022-05-09_00:15:25_0.json

View File

@@ -0,0 +1,16 @@
***********************************************************************************************************************
* Report STree ver. 1.2.4 with 5 Folds cross validation and 10 random seeds. 2022-05-09 00:21:06 *
* test *
* Random seeds: [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] Stratified: False *
* Execution took 0.89 seconds, 0.00 hours, on iMac27 *
* Score is accuracy *
***********************************************************************************************************************
Dataset Samp Feat. Cls Nodes Leaves Depth Score Time Hyperparameters
============================== ===== ===== === ======= ======= ======= =============== ================ ===============
balance-scale 625 4 3 26.12 13.56 7.94 0.910720±0.0249 0.015852±0.0027 {'C': 1.0, 'kernel': 'liblinear', 'multiclass_strategy': 'ovr'}
balloons 16 4 2 4.64 2.82 2.66 0.663333±0.3009 0.000640±0.0001 {'C': 1.0, 'kernel': 'linear', 'multiclass_strategy': 'ovr'}
***********************************************************************************************************************
* Accuracy compared to stree_default (liblinear-ovr) .: 0.0391 *
***********************************************************************************************************************
Results in results/results_accuracy_STree_iMac27_2022-05-09_00:21:06_0.json