Add be_build_grid and fix some scripts issues

This commit is contained in:
2022-05-05 20:19:50 +02:00
parent 5bcd4beca9
commit 1cefc51870
7 changed files with 42 additions and 17 deletions

View File

@@ -26,9 +26,11 @@ class EnvData:
class EnvDefault(argparse.Action): class EnvDefault(argparse.Action):
# Thanks to https://stackoverflow.com/users/445507/russell-heilling # Thanks to https://stackoverflow.com/users/445507/russell-heilling
def __init__(self, envvar, required=True, default=None, **kwargs): def __init__(
self, envvar, required=True, default=None, mandatory=False, **kwargs
):
self._args = EnvData.load() self._args = EnvData.load()
if required: if required and not mandatory:
default = self._args[envvar] default = self._args[envvar]
required = False required = False
super(EnvDefault, self).__init__( super(EnvDefault, self).__init__(

View File

@@ -144,6 +144,7 @@ class BestResults:
score=self.score_name, model=self.model score=self.score_name, model=self.model
) )
all_files = sorted(list(os.walk(Folders.results))) all_files = sorted(list(os.walk(Folders.results)))
found = False
for root, _, files in tqdm( for root, _, files in tqdm(
all_files, desc="files", disable=self.quiet all_files, desc="files", disable=self.quiet
): ):
@@ -153,6 +154,9 @@ class BestResults:
with open(file_name) as fp: with open(file_name) as fp:
data = json.load(fp) data = json.load(fp)
self._process_datafile(results, data, name) self._process_datafile(results, data, name)
found = True
if not found:
raise ValueError(f"No results found")
# Build best results json file # Build best results json file
output = {} output = {}
datasets = Datasets() datasets = Datasets()
@@ -374,10 +378,6 @@ class GridSearch:
self.grid_file = os.path.join( self.grid_file = os.path.join(
Folders.results, Files.grid_input(score_name, model_name) Folders.results, Files.grid_input(score_name, model_name)
) )
with open(self.grid_file) as f:
self.grid = json.load(f)
self.duration = 0
self._init_data()
def get_output_file(self): def get_output_file(self):
return self.output_file return self.output_file
@@ -426,6 +426,10 @@ class GridSearch:
self.results[name] = [score, hyperparameters, message] self.results[name] = [score, hyperparameters, message]
def do_gridsearch(self): def do_gridsearch(self):
with open(self.grid_file) as f:
self.grid = json.load(f)
self.duration = 0
self._init_data()
now = time.time() now = time.time()
loop = tqdm( loop = tqdm(
list(self.datasets), list(self.datasets),

View File

@@ -1124,7 +1124,7 @@ class Summary:
color2 = TextColor.LINE2 color2 = TextColor.LINE2
print(color1, end="") print(color1, end="")
print( print(
f"{'Date':10s} {'File':{max_file}s} {'Score':7s} {'Time(h)':7s} " f"{'Date':10s} {'File':{max_file}s} {'Score':8s} {'Time(h)':7s} "
f"{'Title':s}" f"{'Title':s}"
) )
print( print(

View File

@@ -13,7 +13,11 @@ def main():
args = arguments.parse() args = arguments.parse()
datasets = Datasets() datasets = Datasets()
best = BestResults(args.score, args.model, datasets) best = BestResults(args.score, args.model, datasets)
try:
best.build() best.build()
except ValueError:
print("** No results found **")
else:
if args.report: if args.report:
report = ReportBest(args.score, args.model, best=True, grid=False) report = ReportBest(args.score, args.model, best=True, grid=False)
report.report() report.report()

View File

@@ -2,9 +2,17 @@
import os import os
import json import json
from benchmark.Utils import Files, Folders from benchmark.Utils import Files, Folders
from benchmark.Arguments import Arguments
"""Build sample grid input file for the model with data taken from the
input grid used optimizing STree
"""
def main(): def main():
arguments = Arguments()
arguments.xset("model", mandatory=True).xset("score", mandatory=True)
args = arguments.parse()
data = [ data = [
'{"C": 1e4, "gamma": 0.1, "kernel": "rbf"}', '{"C": 1e4, "gamma": 0.1, "kernel": "rbf"}',
'{"C": 7, "gamma": 0.14, "kernel": "rbf"}', '{"C": 7, "gamma": 0.14, "kernel": "rbf"}',
@@ -105,8 +113,8 @@ def main():
output.append(results_tmp) output.append(results_tmp)
# save results # save results
file_name = Files.grid_input("accuracy", "ODTE") file_name = Files.grid_input(args.score, args.model)
file_output = os.path.join(Folders.results, file_name) file_output = os.path.join(Folders.results, file_name)
with open(file_output, "w") as f: with open(file_output, "w") as f:
json.dump(output, f, indent=4) json.dump(output, f, indent=4)
print(f"Grid values saved to {file_output}") print(f"Generated grid input file to {file_output}")

View File

@@ -8,9 +8,11 @@ from benchmark.Arguments import Arguments
def main(): def main():
arguments = Arguments() arguments = Arguments()
arguments.xset("score").xset("platform").xset("model").xset("n_folds") arguments.xset("score").xset("platform").xset("model")
arguments.xset("quiet").xset("stratified").xset("dataset") arguments.xset("quiet").xset("stratified").xset("dataset").xset("n_folds")
args = arguments.parse() args = arguments.parse()
if not args.quiet:
print(f"Perform grid search with {args.model} model")
job = GridSearch( job = GridSearch(
score_name=args.score, score_name=args.score,
model_name=args.model, model_name=args.model,
@@ -18,6 +20,10 @@ def main():
datasets=Datasets(dataset_name=args.dataset), datasets=Datasets(dataset_name=args.dataset),
progress_bar=not args.quiet, progress_bar=not args.quiet,
platform=args.platform, platform=args.platform,
folds=args.folds, folds=args.n_folds,
) )
try:
job.do_gridsearch() job.do_gridsearch()
except FileNotFoundError:
print(f"** The grid input file [{job.grid_file}] could not be found")
print("")

View File

@@ -67,6 +67,7 @@ setuptools.setup(
"be_benchmark=benchmark.scripts.be_benchmark:main", "be_benchmark=benchmark.scripts.be_benchmark:main",
"be_best=benchmark.scripts.be_best:main", "be_best=benchmark.scripts.be_best:main",
"be_build_best=benchmark.scripts.be_build_best:main", "be_build_best=benchmark.scripts.be_build_best:main",
"be_build_grid=benchmark.scripts.be_build_grid:main",
"be_grid=benchmark.scripts.be_grid:main", "be_grid=benchmark.scripts.be_grid:main",
"be_pair_check=benchmark.scripts.be_pair_check:main", "be_pair_check=benchmark.scripts.be_pair_check:main",
"be_print_strees=benchmark.scripts.be_print_strees:main", "be_print_strees=benchmark.scripts.be_print_strees:main",