add kernel hyperparameter subset in gridsearch

This commit is contained in:
2020-11-29 00:00:27 +01:00
parent b14edf4303
commit 2decec05fb
4 changed files with 98 additions and 36 deletions

View File

@@ -49,35 +49,46 @@ class ModelStree(ModelBase):
gamma = [1e-1, 1, 1e1]
max_features = [None, "auto"]
split_criteria = ["impurity", "max_samples"]
self._linear = {
"random_state": [self._random_state],
"C": C,
"max_iter": max_iter,
"split_criteria": split_criteria,
"max_features": max_features,
}
self._poly = {
"random_state": [self._random_state],
"kernel": ["rbf"],
"C": C,
"gamma": gamma,
"max_iter": max_iter,
"split_criteria": split_criteria,
"max_features": max_features,
}
self._rbf = {
"random_state": [self._random_state],
"kernel": ["poly"],
"degree": [3, 5],
"C": C,
"gamma": gamma,
"max_iter": max_iter,
"split_criteria": split_criteria,
"max_features": max_features,
}
self._param_grid = [
{
"random_state": [self._random_state],
"C": C,
"max_iter": max_iter,
"split_criteria": split_criteria,
"max_features": max_features,
},
{
"random_state": [self._random_state],
"kernel": ["rbf"],
"C": C,
"gamma": gamma,
"max_iter": max_iter,
"split_criteria": split_criteria,
"max_features": max_features,
},
{
"random_state": [self._random_state],
"kernel": ["poly"],
"degree": [3, 5],
"C": C,
"gamma": gamma,
"max_iter": max_iter,
"split_criteria": split_criteria,
"max_features": max_features,
},
self._linear,
self._poly,
self._rbf,
]
def select_params(self, kernel: str) -> None:
if kernel == "linear":
self._param_grid = [self._linear]
elif kernel == "poly":
self._param_grid = [self._poly]
else:
self._param_grid = [self._rbf]
class ModelSVC(ModelBase):
def __init__(self, random_state: Optional[int] = None) -> None: