mirror of
https://github.com/Doctorado-ML/Stree_datasets.git
synced 2025-08-15 23:46:03 +00:00
get mysql stored params first in cross validation
This commit is contained in:
@@ -6,7 +6,7 @@ import warnings
|
||||
from sklearn.model_selection import GridSearchCV, cross_validate
|
||||
|
||||
from . import Models
|
||||
from .Database import Hyperparameters, Outcomes
|
||||
from .Database import Hyperparameters, Outcomes, MySQL
|
||||
from .Sets import Datasets
|
||||
|
||||
|
||||
@@ -20,12 +20,8 @@ class Experiment:
|
||||
kernel: str,
|
||||
) -> None:
|
||||
self._random_state = random_state
|
||||
self._model_name = model
|
||||
self._set_model(model)
|
||||
self._set_of_files = set_of_files
|
||||
self._type = getattr(
|
||||
Models,
|
||||
f"Model{model[0].upper() + model[1:]}",
|
||||
)
|
||||
self._clf = self._type(random_state=self._random_state)
|
||||
self._host = host
|
||||
# used in gridsearch with ensembles to take best hyperparams of
|
||||
@@ -36,15 +32,35 @@ class Experiment:
|
||||
def set_base_params(self, base_params: str) -> None:
|
||||
self._base_params = base_params
|
||||
|
||||
def _set_model(self, model_name: str) -> None:
|
||||
self._model_name = model_name
|
||||
self._type = getattr(
|
||||
Models,
|
||||
f"Model{model_name[0].upper() + model_name[1:]}",
|
||||
)
|
||||
|
||||
def cross_validation(self, dataset: str) -> None:
|
||||
hyperparams = Hyperparameters(host=self._host, model=self._model_name)
|
||||
try:
|
||||
parameters, normalize, standardize = hyperparams.get_params(
|
||||
dataset
|
||||
)
|
||||
except ValueError:
|
||||
print(f"*** {dataset} not trained")
|
||||
return
|
||||
self._clf = self._type(random_state=self._random_state)
|
||||
model = self._clf.get_model()
|
||||
hyperparams = MySQL()
|
||||
hyperparams.get_connection()
|
||||
record = hyperparams.find_best(dataset, self._model_name)
|
||||
hyperparams.close()
|
||||
if record is None:
|
||||
try:
|
||||
hyperparams = Hyperparameters(
|
||||
host=self._host, model=self._model_name
|
||||
)
|
||||
parameters, normalize, standardize = hyperparams.get_params(
|
||||
dataset
|
||||
)
|
||||
except ValueError:
|
||||
print(f"*** {dataset} not trained")
|
||||
return
|
||||
else:
|
||||
normalize = record[6]
|
||||
standardize = record[7]
|
||||
parameters = record[8]
|
||||
datasets = Datasets(
|
||||
normalize=normalize,
|
||||
standardize=standardize,
|
||||
@@ -54,7 +70,7 @@ class Experiment:
|
||||
X, y = datasets.load(dataset)
|
||||
# init cross validation object just in case consecutive experiments
|
||||
self._clf = self._type(random_state=self._random_state)
|
||||
model = self._clf.get_model().set_params(**parameters)
|
||||
model.set_params(**parameters)
|
||||
self._num_warnings = 0
|
||||
warnings.warn = self._warn
|
||||
with warnings.catch_warnings():
|
||||
|
Reference in New Issue
Block a user