Add be_build_grid and fix some scripts issues

2022-05-05 20:19:50 +02:00
parent 5bcd4beca9
commit 1cefc51870
7 changed files with 42 additions and 17 deletions

View File

@@ -26,9 +26,11 @@ class EnvData:
class EnvDefault(argparse.Action):
# Thanks to https://stackoverflow.com/users/445507/russell-heilling
def __init__(self, envvar, required=True, default=None, **kwargs):
def __init__(
self, envvar, required=True, default=None, mandatory=False, **kwargs
):
self._args = EnvData.load()
if required:
if required and not mandatory:
default = self._args[envvar]
required = False
super(EnvDefault, self).__init__(
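
The new mandatory keyword lets a script insist on an explicit command-line value even when a previously saved value exists. A minimal, self-contained sketch of the pattern, where a plain dict stands in for EnvData.load() and the flag spellings are illustrative assumptions:

import argparse

_SAVED = {"model": "ODTE", "score": "accuracy"}  # stand-in for EnvData.load()


class EnvDefault(argparse.Action):
    # Saved values become defaults unless the argument is declared mandatory,
    # in which case it must still be supplied on the command line.
    def __init__(
        self, envvar, required=True, default=None, mandatory=False, **kwargs
    ):
        self._args = dict(_SAVED)
        if required and not mandatory:
            default = self._args[envvar]
            required = False
        super().__init__(default=default, required=required, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, values)


parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", action=EnvDefault, envvar="model", mandatory=True)
parser.add_argument("-s", "--score", action=EnvDefault, envvar="score")
print(parser.parse_args(["-m", "STree"]))  # score falls back to the saved value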

View File

@@ -144,6 +144,7 @@ class BestResults:
score=self.score_name, model=self.model
)
all_files = sorted(list(os.walk(Folders.results)))
found = False
for root, _, files in tqdm(
all_files, desc="files", disable=self.quiet
):
@@ -153,6 +154,9 @@ class BestResults:
with open(file_name) as fp:
data = json.load(fp)
self._process_datafile(results, data, name)
found = True
if not found:
raise ValueError(f"No results found")
# Build best results json file
output = {}
datasets = Datasets()
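
The new found flag turns "the results folder had nothing to process" into an explicit error instead of silently producing an empty best-results file. A minimal sketch of the guard; collect_results and the results_*.json file pattern are assumptions for illustration, not the real BestResults API:

import json
import os


def collect_results(results_dir):
    # Walk the results folder; raise if no result file was processed at all.
    found = False
    results = {}
    for root, _, files in sorted(os.walk(results_dir)):
        for name in files:
            if not name.startswith("results_") or not name.endswith(".json"):
                continue
            with open(os.path.join(root, name)) as fp:
                results[name] = json.load(fp)
            found = True
    if not found:
        raise ValueError("No results found")
    return results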
@@ -374,10 +378,6 @@ class GridSearch:
self.grid_file = os.path.join(
Folders.results, Files.grid_input(score_name, model_name)
)
with open(self.grid_file) as f:
self.grid = json.load(f)
self.duration = 0
self._init_data()
def get_output_file(self):
return self.output_file
@@ -426,6 +426,10 @@ class GridSearch:
self.results[name] = [score, hyperparameters, message]
def do_gridsearch(self):
with open(self.grid_file) as f:
self.grid = json.load(f)
self.duration = 0
self._init_data()
now = time.time()
loop = tqdm(
list(self.datasets),
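
Moving the grid-file read out of __init__ and into do_gridsearch() means a missing grid input file no longer crashes at construction time; it raises FileNotFoundError while the search runs, where the calling script can catch it. A reduced sketch of that shape, assuming a hypothetical GridSearchSketch class and an illustrative file name:

import json


class GridSearchSketch:
    def __init__(self, grid_file):
        # Only the path is stored here; the file is not touched yet.
        self.grid_file = grid_file

    def do_gridsearch(self):
        # The file is opened lazily, so a missing file surfaces right here.
        with open(self.grid_file) as f:
            self.grid = json.load(f)
        self.duration = 0
        # ... iterate over datasets and hyperparameter combinations ...


try:
    GridSearchSketch("results/grid_input_accuracy_ODTE.json").do_gridsearch()
except FileNotFoundError as error:
    print(f"** The grid input file [{error.filename}] could not be found")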

View File

@@ -1124,7 +1124,7 @@ class Summary:
color2 = TextColor.LINE2
print(color1, end="")
print(
f"{'Date':10s} {'File':{max_file}s} {'Score':7s} {'Time(h)':7s} "
f"{'Date':10s} {'File':{max_file}s} {'Score':8s} {'Time(h)':7s} "
f"{'Title':s}"
)
print(
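
The header width for 'Score' grows from 7 to 8 characters so it stays aligned with the score values printed below it; a score with six decimals occupies exactly 8 characters. A tiny illustration, where the value formats are assumptions about the report layout:

# Header row and a data row padded to the same field widths.
print(f"{'Date':10s} {'Score':8s} {'Time(h)':7s} Title")
print(f"{'2022-05-05':10s} {0.987654:8.6f} {1.25:7.2f} sample run")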

View File

@@ -13,7 +13,11 @@ def main():
args = arguments.parse()
datasets = Datasets()
best = BestResults(args.score, args.model, datasets)
try:
best.build()
except ValueError:
print("** No results found **")
else:
if args.report:
report = ReportBest(args.score, args.model, best=True, grid=False)
report.report()
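
With the guard in BestResults.build() in place, be_build_best can print a friendly message and skip the report step when nothing was found. A compact sketch of that control flow; build_and_report and _NoResults are hypothetical helpers, not the script's real structure:

def build_and_report(best, report=None):
    try:
        best.build()
    except ValueError:
        print("** No results found **")
    else:
        # The report step only runs when the build succeeded.
        if report is not None:
            report.report()


class _NoResults:
    def build(self):
        raise ValueError("No results found")


build_and_report(_NoResults())  # prints "** No results found **"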

View File

@@ -2,9 +2,17 @@
import os
import json
from benchmark.Utils import Files, Folders
from benchmark.Arguments import Arguments
"""Build a sample grid input file for the model, with data taken from the
input grid used to optimize STree
"""
def main():
arguments = Arguments()
arguments.xset("model", mandatory=True).xset("score", mandatory=True)
args = arguments.parse()
data = [
'{"C": 1e4, "gamma": 0.1, "kernel": "rbf"}',
'{"C": 7, "gamma": 0.14, "kernel": "rbf"}',
@@ -105,8 +113,8 @@ def main():
output.append(results_tmp)
# save results
file_name = Files.grid_input("accuracy", "ODTE")
file_name = Files.grid_input(args.score, args.model)
file_output = os.path.join(Folders.results, file_name)
with open(file_output, "w") as f:
json.dump(output, f, indent=4)
print(f"Grid values saved to {file_output}")
print(f"Generated grid input file to {file_output}")
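
After this change the generated file name follows the score and model passed on the command line instead of being hard-coded to accuracy/ODTE. A hedged sketch of what the script produces; the grid_input_<score>_<model>.json naming scheme and the hyperparameter grids are assumptions for illustration:

import json
import os


def build_grid(score, model, results_dir="results"):
    # Sample hyperparameter grids seeded from values used when tuning STree.
    output = [
        {"C": [1, 7, 1e4], "gamma": [0.1, 0.14, "scale"], "kernel": ["rbf"]},
        {"C": [0.05, 0.2, 1], "kernel": ["linear"]},
    ]
    file_name = f"grid_input_{score}_{model}.json"  # assumed naming scheme
    file_output = os.path.join(results_dir, file_name)
    os.makedirs(results_dir, exist_ok=True)
    with open(file_output, "w") as f:
        json.dump(output, f, indent=4)
    print(f"Generated grid input file to {file_output}")
    return file_output


build_grid("accuracy", "ODTE")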

View File

@@ -8,9 +8,11 @@ from benchmark.Arguments import Arguments
def main():
arguments = Arguments()
arguments.xset("score").xset("platform").xset("model").xset("n_folds")
arguments.xset("quiet").xset("stratified").xset("dataset")
arguments.xset("score").xset("platform").xset("model")
arguments.xset("quiet").xset("stratified").xset("dataset").xset("n_folds")
args = arguments.parse()
if not args.quiet:
print(f"Perform grid search with {args.model} model")
job = GridSearch(
score_name=args.score,
model_name=args.model,
@@ -18,6 +20,10 @@ def main():
datasets=Datasets(dataset_name=args.dataset),
progress_bar=not args.quiet,
platform=args.platform,
folds=args.folds,
folds=args.n_folds,
)
try:
job.do_gridsearch()
except FileNotFoundError:
print(f"** The grid input file [{job.grid_file}] could not be found")
print("")
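
Two fixes land here: n_folds joins the chained xset() calls, and the GridSearch keyword now passes args.n_folds, matching the registered argument name. A minimal sketch of the fluent-arguments idea; the real Arguments class lives in benchmark.Arguments, and the flag spellings and defaults below are assumptions:

import argparse


class ArgumentsSketch:
    _known = {
        "score": dict(flags=("-s", "--score"), default="accuracy"),
        "model": dict(flags=("-m", "--model"), required=True),
        "n_folds": dict(flags=("-f", "--n_folds"), type=int, default=5),
    }

    def __init__(self):
        self.parser = argparse.ArgumentParser()

    def xset(self, name, **overrides):
        # Register one known argument and return self so calls can chain.
        spec = {**self._known[name], **overrides}
        flags = spec.pop("flags")
        self.parser.add_argument(*flags, dest=name, **spec)
        return self

    def parse(self, args=None):
        return self.parser.parse_args(args)


args = ArgumentsSketch().xset("score").xset("model").xset("n_folds").parse(
    ["-m", "STree"]
)
print(args.score, args.model, args.n_folds)  # accuracy STree 5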

View File

@@ -67,6 +67,7 @@ setuptools.setup(
"be_benchmark=benchmark.scripts.be_benchmark:main",
"be_best=benchmark.scripts.be_best:main",
"be_build_best=benchmark.scripts.be_build_best:main",
"be_build_grid=benchmark.scripts.be_build_grid:main",
"be_grid=benchmark.scripts.be_grid:main",
"be_pair_check=benchmark.scripts.be_pair_check:main",
"be_print_strees=benchmark.scripts.be_print_strees:main",