Add 10 random seeds run in crossval

Add testwodt comparison
This commit is contained in:
2021-03-10 16:42:18 +01:00
parent d4cfe77b18
commit e791d2edf5
5 changed files with 127 additions and 15 deletions

View File

@@ -75,16 +75,25 @@ class Experiment:
model.set_params(**parameters)
self._num_warnings = 0
warnings.warn = self._warn
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
# Also affect subprocesses
os.environ["PYTHONWARNINGS"] = "ignore"
results = cross_validate(
model, X, y, return_train_score=True, n_jobs=self._threads
)
# Execute with 10 different seeds to ease the random effect
total = {}
outcomes = ["fit_time", "score_time", "train_score", "test_score"]
for item in outcomes:
total[item] = []
for random_state in [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]:
model.set_params(**{"random_state": random_state})
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
# Also affect subprocesses
os.environ["PYTHONWARNINGS"] = "ignore"
results = cross_validate(
model, X, y, return_train_score=True, n_jobs=self._threads
)
for item in outcomes:
total[item].append(results[item])
outcomes = Outcomes(host=self._host, model=self._model_name)
parameters = json.dumps(parameters, sort_keys=True)
outcomes.store(dataset, normalize, standardize, parameters, results)
outcomes.store(dataset, normalize, standardize, parameters, total)
if self._num_warnings > 0:
print(f"{self._num_warnings} warnings have happend")