Add grid file param to cross validation experiment

This commit is contained in:
2022-03-09 16:47:52 +01:00
parent 14f0886662
commit f031d67668
5 changed files with 110 additions and 88 deletions

View File

@@ -140,6 +140,7 @@ class Experiment:
datasets,
hyperparams_dict,
hyperparams_file,
grid_paramfile,
platform,
title,
progress_bar=True,
@@ -173,6 +174,12 @@ class Experiment:
self.hyperparameters_dict = hyper.load(
dictionary=dictionary,
)
elif grid_paramfile:
grid_file = os.path.join(
Folders.results, Files.grid_output(score_name, model_name)
)
with open(grid_file) as f:
self.hyperparameters_dict = json.load(f)
else:
self.hyperparameters_dict = hyper.fill(
dictionary=dictionary,
@@ -354,7 +361,7 @@ class GridSearch:
for item in self.datasets:
data[item] = self.results[item]
with open(self.output_file, "w") as f:
json.dump(data, f)
json.dump(data, f, indent=4)
def _store_result(self, name, grid, duration):
d_message = f"{duration:.3f} s"