Add grid file param to cross validation experiment

This commit is contained in:
2022-03-09 16:47:52 +01:00
parent 14f0886662
commit f031d67668
5 changed files with 110 additions and 88 deletions

View File

@@ -140,6 +140,7 @@ class Experiment:
datasets,
hyperparams_dict,
hyperparams_file,
grid_paramfile,
platform,
title,
progress_bar=True,
@@ -173,6 +174,12 @@ class Experiment:
self.hyperparameters_dict = hyper.load(
dictionary=dictionary,
)
elif grid_paramfile:
grid_file = os.path.join(
Folders.results, Files.grid_output(score_name, model_name)
)
with open(grid_file) as f:
self.hyperparameters_dict = json.load(f)
else:
self.hyperparameters_dict = hyper.fill(
dictionary=dictionary,
@@ -354,7 +361,7 @@ class GridSearch:
for item in self.datasets:
data[item] = self.results[item]
with open(self.output_file, "w") as f:
json.dump(data, f)
json.dump(data, f, indent=4)
def _store_result(self, name, grid, duration):
d_message = f"{duration:.3f} s"

View File

@@ -75,6 +75,7 @@ results["linear"]["multiclass_strategy"] = ["ovo"]
del results["linear"]["gamma"]
del results["liblinear"]["gamma"]
results["rbf"]["gamma"].append("scale")
results["poly"]["gamma"].append("scale")
results["poly"]["multiclass_strategy"].append("ovo")
for kernel in kernels:
results[kernel]["C"].append(1.0)
@@ -101,5 +102,5 @@ for item in results:
file_name = Files.grid_input("accuracy", "ODTE")
file_output = os.path.join(Folders.results, file_name)
with open(file_output, "w") as f:
json.dump(output, f)
json.dump(output, f, indent=4)
print(f"Grid values saved to {file_output}")

View File

@@ -49,7 +49,20 @@ def parse_arguments():
"-p", "--hyperparameters", type=str, required=False, default="{}"
)
ap.add_argument(
"-f", "--paramfile", type=bool, required=False, default=False
"-f",
"--paramfile",
type=bool,
required=False,
default=False,
help="Use best hyperparams file?",
)
ap.add_argument(
"-g",
"--grid_paramfile",
type=bool,
required=False,
default=False,
help="Use grid searched hyperparams file?",
)
ap.add_argument(
"--title", type=str, required=True, help="experiment title"
@@ -97,6 +110,7 @@ def parse_arguments():
args.quiet,
args.hyperparameters,
args.paramfile,
args.grid_paramfile,
args.report,
args.title,
args.dataset,
@@ -112,11 +126,14 @@ def parse_arguments():
quiet,
hyperparameters,
paramfile,
grid_paramfile,
report,
experiment_title,
dataset,
) = parse_arguments()
report = report or dataset is not None
if grid_paramfile:
paramfile = False
job = Experiment(
score_name=score,
model_name=model,
@@ -124,6 +141,7 @@ job = Experiment(
datasets=Datasets(dataset=dataset),
hyperparams_dict=hyperparameters,
hyperparams_file=paramfile,
grid_paramfile=grid_paramfile,
progress_bar=not quiet,
platform=platform,
title=experiment_title,