mirror of
https://github.com/Doctorado-ML/benchmark.git
synced 2025-08-17 16:35:54 +00:00
Add grid file param to cross validation experiment
This commit is contained in:
@@ -140,6 +140,7 @@ class Experiment:
|
||||
datasets,
|
||||
hyperparams_dict,
|
||||
hyperparams_file,
|
||||
grid_paramfile,
|
||||
platform,
|
||||
title,
|
||||
progress_bar=True,
|
||||
@@ -173,6 +174,12 @@ class Experiment:
|
||||
self.hyperparameters_dict = hyper.load(
|
||||
dictionary=dictionary,
|
||||
)
|
||||
elif grid_paramfile:
|
||||
grid_file = os.path.join(
|
||||
Folders.results, Files.grid_output(score_name, model_name)
|
||||
)
|
||||
with open(grid_file) as f:
|
||||
self.hyperparameters_dict = json.load(f)
|
||||
else:
|
||||
self.hyperparameters_dict = hyper.fill(
|
||||
dictionary=dictionary,
|
||||
@@ -354,7 +361,7 @@ class GridSearch:
|
||||
for item in self.datasets:
|
||||
data[item] = self.results[item]
|
||||
with open(self.output_file, "w") as f:
|
||||
json.dump(data, f)
|
||||
json.dump(data, f, indent=4)
|
||||
|
||||
def _store_result(self, name, grid, duration):
|
||||
d_message = f"{duration:.3f} s"
|
||||
|
Reference in New Issue
Block a user