Continue be_grid tests

This commit is contained in:
2022-05-08 00:12:52 +02:00
parent 986341723c
commit 5c4d5cb99e
4 changed files with 50 additions and 8 deletions

View File

@@ -26,4 +26,3 @@ def main(args_test=None):
job.do_gridsearch() job.do_gridsearch()
except FileNotFoundError: except FileNotFoundError:
print(f"** The grid input file [{job.grid_file}] could not be found") print(f"** The grid input file [{job.grid_file}] could not be found")
print("")

View File

@@ -80,7 +80,6 @@ class ModelTest(TestBase):
"GBC": ((15, 8, 3), 1.0), "GBC": ((15, 8, 3), 1.0),
} }
X, y = load_wine(return_X_y=True) X, y = load_wine(return_X_y=True)
print("")
for key, (value, score_expected) in test.items(): for key, (value, score_expected) in test.items():
clf = Models.get_model(key, random_state=1) clf = Models.get_model(key, random_state=1)
clf.fit(X, y) clf.fit(X, y)

View File

@@ -0,0 +1,6 @@
[
{
"C": [1.0, 5.0],
"kernel": ["linear", "rbf", "poly"]
}
]

View File

@@ -1,5 +1,6 @@
import os import os
from ...Utils import Folders import json
from ...Utils import Folders, Files
from ..TestBase import TestBase from ..TestBase import TestBase
@@ -8,7 +9,13 @@ class BeGridTest(TestBase):
self.prepare_scripts_env() self.prepare_scripts_env()
def tearDown(self) -> None: def tearDown(self) -> None:
self.remove_files(["grid_input_f1-macro_STree.json"], Folders.results) self.remove_files(
[
Files.grid_input("f1-macro", "STree"),
Files.grid_output("accuracy", "SVC"),
],
Folders.results,
)
return super().tearDown() return super().tearDown()
def test_be_build_grid(self): def test_be_build_grid(self):
@@ -21,17 +28,48 @@ class BeGridTest(TestBase):
"Generated grid input file to results/grid_input_f1-macro_STree." "Generated grid input file to results/grid_input_f1-macro_STree."
"json\n", "json\n",
) )
name = File.grid_input("f1-macro", "STree") name = Files.grid_input("f1-macro", "STree")
file_name = os.path.join(Folders.results, name) file_name = os.path.join(Folders.results, name)
self.check_file_file(file_name, "be_build_grid") self.check_file_file(file_name, "be_build_grid")
def test_be_grid_(self): def test_be_grid_(self):
stdout, stderr = self.execute_script( stdout, stderr = self.execute_script(
"be_grid", "be_grid",
["-m", "STree", "-s", "accuracy", "--n_folds", 2, "-q", "1"], ["-m", "SVC", "-s", "accuracy", "--n_folds", "2", "-q", "1"],
) )
self.assertEqual(stderr.getvalue(), "") self.assertEqual(stderr.getvalue(), "")
self.assertEqual(stdout.getvalue(), "") self.assertEqual(stdout.getvalue(), "")
name = File.grid_output("accuracy", "STree") name = Files.grid_output("accuracy", "SVC")
file_name = os.path.join(Folders.results, name) file_name = os.path.join(Folders.results, name)
self.check_file_file(file_name, "be_grid") with open(file_name, "r") as f:
computed_data = json.load(f)
expected_data = {
"balance-scale": [
0.9167895469812403,
{"C": 5.0, "kernel": "linear"},
"v. -, Computed on iMac27 on 2022-05-07 at 23:55:03 took",
],
"balloons": [
0.6875,
{"C": 5.0, "kernel": "rbf"},
"v. -, Computed on iMac27 on 2022-05-07 at 23:55:03 took",
],
}
for computed, expected in zip(computed_data, expected_data):
self.assertEqual(computed, expected)
for key, value in expected_data.items():
self.assertIn(key, computed_data)
self.assertEqual(computed_data[key][0], value[0])
self.assertEqual(computed_data[key][1], value[1])
def test_be_grid_no_input(self):
stdout, stderr = self.execute_script(
"be_grid",
["-m", "ODTE", "-s", "f1-weighted", "-q", "1"],
)
self.assertEqual(stderr.getvalue(), "")
grid_file = os.path.join(
Folders.results, Files.grid_input("f1-weighted", "ODTE")
)
expected = f"** The grid input file [{grid_file}] could not be found\n"
self.assertEqual(stdout.getvalue(), expected)