Get MySQL-stored params first in cross validation

This commit is contained in:
2020-12-21 00:41:16 +01:00
parent dad717f0a3
commit 301a9fcfb1

View File

@@ -6,7 +6,7 @@ import warnings
from sklearn.model_selection import GridSearchCV, cross_validate
from . import Models
from .Database import Hyperparameters, Outcomes
from .Database import Hyperparameters, Outcomes, MySQL
from .Sets import Datasets
@@ -20,12 +20,8 @@ class Experiment:
kernel: str,
) -> None:
self._random_state = random_state
self._model_name = model
self._set_model(model)
self._set_of_files = set_of_files
self._type = getattr(
Models,
f"Model{model[0].upper() + model[1:]}",
)
self._clf = self._type(random_state=self._random_state)
self._host = host
# used in gridsearch with ensembles to take best hyperparams of
@@ -36,15 +32,35 @@ class Experiment:
def set_base_params(self, base_params: str) -> None:
    """Store ``base_params`` on the instance as ``_base_params``.

    :param base_params: value kept verbatim for later use by the
        experiment; it is not parsed or validated here.
    """
    self._base_params = base_params
def _set_model(self, model_name: str) -> None:
    """Record the model name and resolve its class from ``Models``.

    The attribute looked up on ``Models`` is ``"Model"`` followed by
    ``model_name`` with only its first character upper-cased; the rest
    of the name is left untouched.

    :param model_name: name of the model, stored as ``_model_name``.
    :raises IndexError: if ``model_name`` is empty.
    :raises AttributeError: if ``Models`` has no attribute of that name.
    """
    # Keep the name first so it is recorded even if the lookup fails.
    self._model_name = model_name
    capitalised = model_name[0].upper() + model_name[1:]
    self._type = getattr(Models, f"Model{capitalised}")
def cross_validation(self, dataset: str) -> None:
hyperparams = Hyperparameters(host=self._host, model=self._model_name)
try:
parameters, normalize, standardize = hyperparams.get_params(
dataset
)
except ValueError:
print(f"*** {dataset} not trained")
return
self._clf = self._type(random_state=self._random_state)
model = self._clf.get_model()
hyperparams = MySQL()
hyperparams.get_connection()
record = hyperparams.find_best(dataset, self._model_name)
hyperparams.close()
if record is None:
try:
hyperparams = Hyperparameters(
host=self._host, model=self._model_name
)
parameters, normalize, standardize = hyperparams.get_params(
dataset
)
except ValueError:
print(f"*** {dataset} not trained")
return
else:
normalize = record[6]
standardize = record[7]
parameters = record[8]
datasets = Datasets(
normalize=normalize,
standardize=standardize,
@@ -54,7 +70,7 @@ class Experiment:
X, y = datasets.load(dataset)
# init cross validation object just in case consecutive experiments
self._clf = self._type(random_state=self._random_state)
model = self._clf.get_model().set_params(**parameters)
model.set_params(**parameters)
self._num_warnings = 0
warnings.warn = self._warn
with warnings.catch_warnings():