Begin be_grid tests

This commit is contained in:
2022-05-07 23:28:38 +02:00
parent b8c4e30714
commit af95e9c6bc
3 changed files with 140 additions and 0 deletions

View File

@@ -15,6 +15,7 @@ from .scripts.Be_Pair_check_test import BePairCheckTest
from .scripts.Be_List_test import BeListTest from .scripts.Be_List_test import BeListTest
from .scripts.Be_Report_test import BeReportTest from .scripts.Be_Report_test import BeReportTest
from .scripts.Be_Summary_test import BeSummaryTest from .scripts.Be_Summary_test import BeSummaryTest
from .scripts.Be_Grid_test import BeGridTest
all = [ all = [
"UtilTest", "UtilTest",
@@ -34,4 +35,5 @@ all = [
"BeListTest", "BeListTest",
"BeReportTest", "BeReportTest",
"BeSummaryTest", "BeSummaryTest",
"BeGridTest",
] ]

View File

@@ -0,0 +1,33 @@
import os
from ...Utils import Folders
from ..TestBase import TestBase
class BeGridTest(TestBase):
def setUp(self):
self.prepare_scripts_env()
def tearDown(self) -> None:
self.remove_files(["grid_input_f1-macro_STree.json"], Folders.results)
return super().tearDown()
def test_be_build_grid(self):
stdout, stderr = self.execute_script(
"be_build_grid", ["-m", "STree", "-s", "f1-macro"]
)
self.assertEqual(stderr.getvalue(), "")
self.assertEqual(
stdout.getvalue(),
"Generated grid input file to results/grid_input_f1-macro_STree."
"json\n",
)
name = stdout.getvalue().split("/")[1].replace("\n", "")
file_name = os.path.join(Folders.results, name)
self.check_file_file(file_name, "be_build_grid")
def test_be_grid_(self):
stdout, stderr = self.execute_script(
"be_grid",
["-m", "STree", "-s", "accuracy", "--n_folds", 2, "-q", "1"],
)
self.assertEqual(stderr.getvalue(), "")

View File

@@ -0,0 +1,105 @@
[
{
"n_jobs": [
-1
],
"n_estimators": [
100
],
"base_estimator__C": [
1.0
],
"base_estimator__kernel": [
"linear"
],
"base_estimator__multiclass_strategy": [
"ovo"
]
},
{
"n_jobs": [
-1
],
"n_estimators": [
100
],
"base_estimator__C": [
0.001,
0.0275,
0.05,
0.08,
0.2,
0.25,
0.95,
1.0,
1.75,
7,
10000.0
],
"base_estimator__kernel": [
"liblinear"
],
"base_estimator__multiclass_strategy": [
"ovr"
]
},
{
"n_jobs": [
-1
],
"n_estimators": [
100
],
"base_estimator__C": [
0.05,
1.0,
1.05,
2,
2.8,
2.83,
5,
7,
57,
10000.0
],
"base_estimator__gamma": [
0.001,
0.1,
0.14,
10.0,
"auto",
"scale"
],
"base_estimator__kernel": [
"rbf"
],
"base_estimator__multiclass_strategy": [
"ovr"
]
},
{
"n_jobs": [
-1
],
"n_estimators": [
100
],
"base_estimator__C": [
0.05,
0.2,
1.0,
8.25
],
"base_estimator__gamma": [
0.1,
"scale"
],
"base_estimator__kernel": [
"poly"
],
"base_estimator__multiclass_strategy": [
"ovo",
"ovr"
]
}
]