Add gridsearch experiment

This commit is contained in:
2022-03-09 13:44:49 +01:00
parent 3d0ab041ee
commit c3e05f7d27
5 changed files with 341 additions and 2 deletions

View File

@@ -7,7 +7,12 @@ from datetime import datetime
from tqdm import tqdm
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold, KFold, cross_validate
from sklearn.model_selection import (
StratifiedKFold,
KFold,
GridSearchCV,
cross_validate,
)
from Utils import Folders, Files
from Models import Models
@@ -288,3 +293,119 @@ class Experiment:
self._output_results()
if self.progress_bar:
print(f"Results in {self.output_file}")
class GridSearch:
    """Cross-validated hyperparameter grid search over a set of datasets.

    The parameter grid is read from a JSON file and the best score,
    best hyperparameters and a short provenance message are written, per
    dataset, to a JSON results file. Both file names are derived from the
    score name and the model name through ``Files``.
    """

    def __init__(
        self,
        score_name,
        model_name,
        stratified,
        datasets,
        platform,
        progress_bar=True,
        folds=5,
    ):
        """Set up the search and load (or create) the results file.

        Args:
            score_name: scoring identifier, passed to ``GridSearchCV`` as
                ``scoring`` and used to build the file names.
            model_name: name handed to ``Models.get_model`` and used to
                build the file names.
            stratified: the string ``"1"`` selects ``StratifiedKFold``;
                any other value selects ``KFold``.
            datasets: iterable of dataset names that also provides
                ``load(name)`` returning ``(X, y)``.
            platform: free text recorded in the result message.
            progress_bar: show a tqdm progress bar when true.
            folds: number of cross-validation splits.

        Raises:
            FileNotFoundError: if the grid input file does not exist.
        """
        today = datetime.now()
        self.time = today.strftime("%H:%M:%S")
        self.date = today.strftime("%Y-%m-%d")
        self.output_file = os.path.join(
            Folders.results,
            Files.grid_output(
                score_name,
                model_name,
            ),
        )
        self.score_name = score_name
        self.model_name = model_name
        self.stratified = stratified == "1"
        self.stratified_class = StratifiedKFold if self.stratified else KFold
        self.datasets = datasets
        self.progress_bar = progress_bar
        self.folds = folds
        self.platform = platform
        self.random_seeds = Randomized.seeds
        # Placeholder until a model has been instantiated, so that
        # _store_result never hits a missing attribute.
        self.version = "-"
        self._num_warnings = 0
        self.grid_file = os.path.join(
            Folders.results, Files.grid_input(score_name, model_name)
        )
        with open(self.grid_file, encoding="utf-8") as f:
            self.grid = json.load(f)
        self.duration = 0
        self._init_data()

    def _init_data(self):
        """Load existing results, or create the results file if missing.

        A fresh file holds one ``[0.0, {}, ""]`` entry (score,
        hyperparameters, message) for every item yielded by ``Datasets()``.
        """
        try:
            with open(self.output_file, "r", encoding="utf-8") as f:
                self.results = json.load(f)
        except FileNotFoundError:
            output = {item: [0.0, {}, ""] for item in Datasets()}
            with open(self.output_file, "w", encoding="utf-8") as f:
                json.dump(output, f)
            self.results = output

    def _save_results(self):
        """Merge the results of the processed datasets into the file.

        The file is re-read first so entries of datasets that are not part
        of this run are kept as they are on disk.
        """
        with open(self.output_file, "r", encoding="utf-8") as f:
            data = json.load(f)
        for item in self.datasets:
            data[item] = self.results[item]
        with open(self.output_file, "w", encoding="utf-8") as f:
            json.dump(data, f)

    @staticmethod
    def _format_duration(duration):
        """Render a duration given in seconds as seconds, minutes or hours."""
        if duration > 3600:
            return f"{duration / 3600:.3f} h"
        if duration > 60:
            return f"{duration / 60:.3f} min"
        return f"{duration:.3f} s"

    def _store_result(self, name, grid, duration):
        """Record the outcome of a fitted search for dataset ``name``.

        Args:
            name: dataset name used as key in ``self.results``.
            grid: fitted search object exposing ``best_score_`` and
                ``best_params_``.
            duration: elapsed time in seconds.
        """
        message = (
            f"v. {self.version}, Computed on {self.platform} on "
            f"{self.date} at {self.time} "
            f"took {self._format_duration(duration)}"
        )
        self.results[name] = [grid.best_score_, grid.best_params_, message]

    def do_gridsearch(self):
        """Run the grid search on every dataset and save the results."""
        loop = tqdm(
            list(self.datasets),
            position=0,
            disable=not self.progress_bar,
        )
        for name in loop:
            loop.set_description(f"{name:30s}")
            # Time each dataset on its own; a single start time taken
            # before the loop would make every message report the
            # cumulative time since the run began.
            start = time.time()
            X, y = self.datasets.load(name)
            result = self._n_fold_gridsearch(X, y)
            self._store_result(name, result, time.time() - start)
        self._save_results()

    def _n_fold_gridsearch(self, X, y):
        """Fit a ``GridSearchCV`` on ``(X, y)`` and return the fitted object.

        The splitter is shuffled and seeded with the first random seed.
        Also refreshes ``self.version`` from the model (``"-"`` when the
        model has no ``version`` attribute) and resets the warning counter.
        """
        kfold = self.stratified_class(
            shuffle=True,
            random_state=self.random_seeds[0],
            n_splits=self.folds,
        )
        clf = Models.get_model(self.model_name)
        self.version = clf.version() if hasattr(clf, "version") else "-"
        self._num_warnings = 0
        # catch_warnings() does not undo an assignment to warnings.warn,
        # so restore it explicitly to avoid silencing warnings for the
        # rest of the process.
        original_warn = warnings.warn
        warnings.warn = self._warn
        try:
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore")
                grid = GridSearchCV(
                    estimator=clf,
                    cv=kfold,
                    param_grid=self.grid,
                    scoring=self.score_name,
                    n_jobs=-1,
                )
                grid.fit(X, y)
        finally:
            warnings.warn = original_warn
        return grid

    def _warn(self, *args, **kwargs) -> None:
        """Replacement for ``warnings.warn`` that only counts the calls."""
        self._num_warnings += 1