Continue be_grid tests

This commit is contained in:
2022-05-07 23:33:35 +02:00
parent af95e9c6bc
commit 986341723c
4 changed files with 27 additions and 4 deletions

View File

@@ -450,7 +450,7 @@ class GridSearch:
random_state=self.random_seeds[0],
n_splits=self.folds,
)
clf = Models.get_model(self.model_name)
clf = Models.get_model(self.model_name, self.random_seeds[0])
self.version = clf.version() if hasattr(clf, "version") else "-"
self._num_warnings = 0
warnings.warn = self._warn
@@ -460,7 +460,7 @@ class GridSearch:
estimator=clf,
cv=kfold,
param_grid=self.grid,
scoring=self.score_name,
scoring=self.score_name.replace("-", "_"),
n_jobs=-1,
)
grid.fit(X, y)

View File

@@ -111,7 +111,6 @@ def main(args_test=None):
t2 = sorted([x for x in value if isinstance(x, str)])
results_tmp[new_key] = t1 + t2
output.append(results_tmp)
# save results
file_name = Files.grid_input(args.score, args.model)
file_output = os.path.join(Folders.results, file_name)

View File

@@ -21,7 +21,7 @@ class BeGridTest(TestBase):
"Generated grid input file to results/grid_input_f1-macro_STree."
"json\n",
)
name = stdout.getvalue().split("/")[1].replace("\n", "")
name = File.grid_input("f1-macro", "STree")
file_name = os.path.join(Folders.results, name)
self.check_file_file(file_name, "be_build_grid")
@@ -31,3 +31,7 @@ class BeGridTest(TestBase):
["-m", "STree", "-s", "accuracy", "--n_folds", 2, "-q", "1"],
)
self.assertEqual(stderr.getvalue(), "")
self.assertEqual(stdout.getvalue(), "")
name = File.grid_output("accuracy", "STree")
file_name = os.path.join(Folders.results, name)
self.check_file_file(file_name, "be_grid")

View File

@@ -0,0 +1,20 @@
{
"balance-scale": [
0.9119999999999999,
{
"C": 1.0,
"kernel": "liblinear",
"multiclass_strategy": "ovr"
},
"v. 1.2.4, Computed on iMac27 on 2022-05-07 at 23:29:25 took 0.962s"
],
"balloons": [
0.7,
{
"C": 1.0,
"kernel": "linear",
"multiclass_strategy": "ovr"
},
"v. 1.2.4, Computed on iMac27 on 2022-05-07 at 23:29:25 took 1.232s"
]
}