mirror of
https://github.com/Doctorado-ML/benchmark.git
synced 2025-08-24 03:45:56 +00:00
Add gridsearch experiment
This commit is contained in:
@@ -7,7 +7,12 @@ from datetime import datetime
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.model_selection import StratifiedKFold, KFold, cross_validate
|
||||
from sklearn.model_selection import (
|
||||
StratifiedKFold,
|
||||
KFold,
|
||||
GridSearchCV,
|
||||
cross_validate,
|
||||
)
|
||||
from Utils import Folders, Files
|
||||
from Models import Models
|
||||
|
||||
@@ -288,3 +293,119 @@ class Experiment:
|
||||
self._output_results()
|
||||
if self.progress_bar:
|
||||
print(f"Results in {self.output_file}")
|
||||
|
||||
|
||||
class GridSearch:
|
||||
def __init__(
|
||||
self,
|
||||
score_name,
|
||||
model_name,
|
||||
stratified,
|
||||
datasets,
|
||||
platform,
|
||||
progress_bar=True,
|
||||
folds=5,
|
||||
):
|
||||
today = datetime.now()
|
||||
self.time = today.strftime("%H:%M:%S")
|
||||
self.date = today.strftime("%Y-%m-%d")
|
||||
self.output_file = os.path.join(
|
||||
Folders.results,
|
||||
Files.grid_output(
|
||||
score_name,
|
||||
model_name,
|
||||
),
|
||||
)
|
||||
self.score_name = score_name
|
||||
self.model_name = model_name
|
||||
self.stratified = stratified == "1"
|
||||
self.stratified_class = StratifiedKFold if self.stratified else KFold
|
||||
self.datasets = datasets
|
||||
self.progress_bar = progress_bar
|
||||
self.folds = folds
|
||||
self.platform = platform
|
||||
self.random_seeds = Randomized.seeds
|
||||
self.grid_file = os.path.join(
|
||||
Folders.results, Files.grid_input(score_name, model_name)
|
||||
)
|
||||
with open(self.grid_file) as f:
|
||||
self.grid = json.load(f)
|
||||
self.duration = 0
|
||||
self._init_data()
|
||||
|
||||
def _init_data(self):
|
||||
# if result file not exist initialize it
|
||||
try:
|
||||
with open(self.output_file, "r") as f:
|
||||
self.results = json.load(f)
|
||||
except FileNotFoundError:
|
||||
# init file
|
||||
output = {}
|
||||
data = Datasets()
|
||||
for item in data:
|
||||
output[item] = [0.0, {}, ""]
|
||||
with open(self.output_file, "w") as f:
|
||||
json.dump(output, f)
|
||||
self.results = output
|
||||
|
||||
def _save_results(self):
|
||||
with open(self.output_file, "r") as f:
|
||||
data = json.load(f)
|
||||
for item in self.datasets:
|
||||
data[item] = self.results[item]
|
||||
with open(self.output_file, "w") as f:
|
||||
json.dump(data, f)
|
||||
|
||||
def _store_result(self, name, grid, duration):
|
||||
d_message = f"{duration:.3f} s"
|
||||
if duration > 3600:
|
||||
d_message = f"{duration / 3600:.3f} h"
|
||||
elif duration > 60:
|
||||
d_message = f"{duration / 60:.3f} min"
|
||||
message = (
|
||||
f"v. {self.version}, Computed on {self.platform} on "
|
||||
f"{self.date} at {self.time} "
|
||||
f"took {d_message}"
|
||||
)
|
||||
score = grid.best_score_
|
||||
hyperparameters = grid.best_params_
|
||||
self.results[name] = [score, hyperparameters, message]
|
||||
|
||||
def do_gridsearch(self):
|
||||
now = time.time()
|
||||
loop = tqdm(
|
||||
list(self.datasets),
|
||||
position=0,
|
||||
disable=not self.progress_bar,
|
||||
)
|
||||
for name in loop:
|
||||
loop.set_description(f"{name:30s}")
|
||||
X, y = self.datasets.load(name)
|
||||
result = self._n_fold_gridsearch(X, y)
|
||||
self._store_result(name, result, time.time() - now)
|
||||
self._save_results()
|
||||
|
||||
def _n_fold_gridsearch(self, X, y):
|
||||
kfold = self.stratified_class(
|
||||
shuffle=True,
|
||||
random_state=self.random_seeds[0],
|
||||
n_splits=self.folds,
|
||||
)
|
||||
clf = Models.get_model(self.model_name)
|
||||
self.version = clf.version() if hasattr(clf, "version") else "-"
|
||||
self._num_warnings = 0
|
||||
warnings.warn = self._warn
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore")
|
||||
grid = GridSearchCV(
|
||||
estimator=clf,
|
||||
cv=kfold,
|
||||
param_grid=self.grid,
|
||||
scoring=self.score_name,
|
||||
n_jobs=-1,
|
||||
)
|
||||
grid.fit(X, y)
|
||||
return grid
|
||||
|
||||
def _warn(self, *args, **kwargs) -> None:
|
||||
self._num_warnings += 1
|
||||
|
Reference in New Issue
Block a user