mirror of
https://github.com/Doctorado-ML/benchmark.git
synced 2025-08-17 16:35:54 +00:00
Add grid file param to cross validation experiment
This commit is contained in:
@@ -140,6 +140,7 @@ class Experiment:
|
||||
datasets,
|
||||
hyperparams_dict,
|
||||
hyperparams_file,
|
||||
grid_paramfile,
|
||||
platform,
|
||||
title,
|
||||
progress_bar=True,
|
||||
@@ -173,6 +174,12 @@ class Experiment:
|
||||
self.hyperparameters_dict = hyper.load(
|
||||
dictionary=dictionary,
|
||||
)
|
||||
elif grid_paramfile:
|
||||
grid_file = os.path.join(
|
||||
Folders.results, Files.grid_output(score_name, model_name)
|
||||
)
|
||||
with open(grid_file) as f:
|
||||
self.hyperparameters_dict = json.load(f)
|
||||
else:
|
||||
self.hyperparameters_dict = hyper.fill(
|
||||
dictionary=dictionary,
|
||||
@@ -354,7 +361,7 @@ class GridSearch:
|
||||
for item in self.datasets:
|
||||
data[item] = self.results[item]
|
||||
with open(self.output_file, "w") as f:
|
||||
json.dump(data, f)
|
||||
json.dump(data, f, indent=4)
|
||||
|
||||
def _store_result(self, name, grid, duration):
|
||||
d_message = f"{duration:.3f} s"
|
||||
|
@@ -75,6 +75,7 @@ results["linear"]["multiclass_strategy"] = ["ovo"]
|
||||
del results["linear"]["gamma"]
|
||||
del results["liblinear"]["gamma"]
|
||||
results["rbf"]["gamma"].append("scale")
|
||||
results["poly"]["gamma"].append("scale")
|
||||
results["poly"]["multiclass_strategy"].append("ovo")
|
||||
for kernel in kernels:
|
||||
results[kernel]["C"].append(1.0)
|
||||
@@ -101,5 +102,5 @@ for item in results:
|
||||
file_name = Files.grid_input("accuracy", "ODTE")
|
||||
file_output = os.path.join(Folders.results, file_name)
|
||||
with open(file_output, "w") as f:
|
||||
json.dump(output, f)
|
||||
json.dump(output, f, indent=4)
|
||||
print(f"Grid values saved to {file_output}")
|
||||
|
20
src/main.py
20
src/main.py
@@ -49,7 +49,20 @@ def parse_arguments():
|
||||
"-p", "--hyperparameters", type=str, required=False, default="{}"
|
||||
)
|
||||
ap.add_argument(
|
||||
"-f", "--paramfile", type=bool, required=False, default=False
|
||||
"-f",
|
||||
"--paramfile",
|
||||
type=bool,
|
||||
required=False,
|
||||
default=False,
|
||||
help="Use best hyperparams file?",
|
||||
)
|
||||
ap.add_argument(
|
||||
"-g",
|
||||
"--grid_paramfile",
|
||||
type=bool,
|
||||
required=False,
|
||||
default=False,
|
||||
help="Use grid searched hyperparams file?",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--title", type=str, required=True, help="experiment title"
|
||||
@@ -97,6 +110,7 @@ def parse_arguments():
|
||||
args.quiet,
|
||||
args.hyperparameters,
|
||||
args.paramfile,
|
||||
args.grid_paramfile,
|
||||
args.report,
|
||||
args.title,
|
||||
args.dataset,
|
||||
@@ -112,11 +126,14 @@ def parse_arguments():
|
||||
quiet,
|
||||
hyperparameters,
|
||||
paramfile,
|
||||
grid_paramfile,
|
||||
report,
|
||||
experiment_title,
|
||||
dataset,
|
||||
) = parse_arguments()
|
||||
report = report or dataset is not None
|
||||
if grid_paramfile:
|
||||
paramfile = False
|
||||
job = Experiment(
|
||||
score_name=score,
|
||||
model_name=model,
|
||||
@@ -124,6 +141,7 @@ job = Experiment(
|
||||
datasets=Datasets(dataset=dataset),
|
||||
hyperparams_dict=hyperparameters,
|
||||
hyperparams_file=paramfile,
|
||||
grid_paramfile=grid_paramfile,
|
||||
progress_bar=not quiet,
|
||||
platform=platform,
|
||||
title=experiment_title,
|
||||
|
Reference in New Issue
Block a user