From 301a9fcfb1b1551749b8c22a4f7f5ec927fead6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Mon, 21 Dec 2020 00:41:16 +0100 Subject: [PATCH] get mysql stored params first in cross validation --- experimentation/Experiments.py | 46 +++++++++++++++++++++++----------- 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/experimentation/Experiments.py b/experimentation/Experiments.py index aa41337..a419c98 100644 --- a/experimentation/Experiments.py +++ b/experimentation/Experiments.py @@ -6,7 +6,7 @@ import warnings from sklearn.model_selection import GridSearchCV, cross_validate from . import Models -from .Database import Hyperparameters, Outcomes +from .Database import Hyperparameters, Outcomes, MySQL from .Sets import Datasets @@ -20,12 +20,8 @@ class Experiment: kernel: str, ) -> None: self._random_state = random_state - self._model_name = model + self._set_model(model) self._set_of_files = set_of_files - self._type = getattr( - Models, - f"Model{model[0].upper() + model[1:]}", - ) self._clf = self._type(random_state=self._random_state) self._host = host # used in gridsearch with ensembles to take best hyperparams of @@ -36,15 +32,35 @@ class Experiment: def set_base_params(self, base_params: str) -> None: self._base_params = base_params + def _set_model(self, model_name: str) -> None: + self._model_name = model_name + self._type = getattr( + Models, + f"Model{model_name[0].upper() + model_name[1:]}", + ) + def cross_validation(self, dataset: str) -> None: - hyperparams = Hyperparameters(host=self._host, model=self._model_name) - try: - parameters, normalize, standardize = hyperparams.get_params( - dataset - ) - except ValueError: - print(f"*** {dataset} not trained") - return + self._clf = self._type(random_state=self._random_state) + model = self._clf.get_model() + hyperparams = MySQL() + hyperparams.get_connection() + record = hyperparams.find_best(dataset, self._model_name) + hyperparams.close() + if record is None: + try: + hyperparams = Hyperparameters( + host=self._host, model=self._model_name + ) + parameters, normalize, standardize = hyperparams.get_params( + dataset + ) + except ValueError: + print(f"*** {dataset} not trained") + return + else: + normalize = record[6] + standardize = record[7] + parameters = record[8] datasets = Datasets( normalize=normalize, standardize=standardize, @@ -54,7 +70,7 @@ class Experiment: X, y = datasets.load(dataset) # init cross validation object just in case consecutive experiments self._clf = self._type(random_state=self._random_state) - model = self._clf.get_model().set_params(**parameters) + model.set_params(**parameters) self._num_warnings = 0 warnings.warn = self._warn with warnings.catch_warnings():