Enhance error msgs in be_main

2022-05-09 11:37:53 +02:00
parent ca96d05124
commit 7501ce7761
3 changed files with 78 additions and 21 deletions


@@ -222,8 +222,11 @@ class Experiment:
             grid_file = os.path.join(
                 Folders.results, Files.grid_output(score_name, model_name)
             )
-            with open(grid_file) as f:
-                self.hyperparameters_dict = json.load(f)
+            try:
+                with open(grid_file) as f:
+                    self.hyperparameters_dict = json.load(f)
+            except FileNotFoundError:
+                raise ValueError(f"{grid_file} does not exist")
         else:
             self.hyperparameters_dict = hyper.fill(
                 dictionary=dictionary,
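
The change above converts the low-level FileNotFoundError into a ValueError that carries the missing path, so the command-line layer only has to handle one exception type. A minimal, self-contained sketch of the same pattern (the load_grid_output helper and its arguments are illustrative, not the repository's API):

import json
import os


def load_grid_output(results_dir, file_name):
    """Load a JSON results file, translating a missing file into a
    ValueError with a readable message (illustrative helper)."""
    grid_file = os.path.join(results_dir, file_name)
    try:
        with open(grid_file) as f:
            return json.load(f)
    except FileNotFoundError:
        # Re-raise as ValueError so callers can catch a single type.
        raise ValueError(f"{grid_file} does not exist")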


@@ -18,20 +18,20 @@ def main(args_test=None):
     report = args.report or args.dataset is not None
     if args.grid_paramfile:
         args.paramfile = False
-    job = Experiment(
-        score_name=args.score,
-        model_name=args.model,
-        stratified=args.stratified,
-        datasets=Datasets(dataset_name=args.dataset),
-        hyperparams_dict=args.hyperparameters,
-        hyperparams_file=args.paramfile,
-        grid_paramfile=args.grid_paramfile,
-        progress_bar=not args.quiet,
-        platform=args.platform,
-        title=args.title,
-        folds=args.n_folds,
-    )
     try:
+        job = Experiment(
+            score_name=args.score,
+            model_name=args.model,
+            stratified=args.stratified,
+            datasets=Datasets(dataset_name=args.dataset),
+            hyperparams_dict=args.hyperparameters,
+            hyperparams_file=args.paramfile,
+            grid_paramfile=args.grid_paramfile,
+            progress_bar=not args.quiet,
+            platform=args.platform,
+            title=args.title,
+            folds=args.n_folds,
+        )
         job.do_experiment()
     except ValueError as e:
         print(e)
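
With Experiment raising ValueError for missing files, be_main now builds and runs the job inside a single try/except, so both construction-time and run-time errors are printed as plain messages instead of tracebacks. A reduced sketch of that control flow, assuming a build_job callable that may raise ValueError (the names here are stand-ins, not the script's real signature):

def run_job(build_job):
    """Build and run an experiment-like job, printing expected
    errors instead of letting them surface as tracebacks."""
    try:
        job = build_job()       # may raise ValueError (e.g. missing file)
        job.do_experiment()     # may also raise ValueError
    except ValueError as error:
        print(error)
        return 1
    return 0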


@@ -1,6 +1,8 @@
import os
from io import StringIO from io import StringIO
from unittest.mock import patch from unittest.mock import patch
from ...Results import Report from ...Results import Report
from ...Utils import Files, Folders
from ..TestBase import TestBase from ..TestBase import TestBase
@@ -14,7 +16,7 @@ class BeMainTest(TestBase):
         self.remove_files(self.files, ".")
         return super().tearDown()
 
-    def test_be_benchmark_dataset(self):
+    def test_be_main_dataset(self):
         stdout, _ = self.execute_script(
             "be_main",
             ["-m", "STree", "-d", "balloons", "--title", "test"],
@@ -25,7 +27,7 @@ class BeMainTest(TestBase):
             lines_to_compare=[0, 2, 3, 5, 6, 7, 8, 9, 11, 12, 13],
         )
 
-    def test_be_benchmark_complete(self):
+    def test_be_main_complete(self):
         stdout, _ = self.execute_script(
             "be_main",
             ["-s", self.score, "-m", "STree", "--title", "test", "-r", "1"],
@@ -37,7 +39,7 @@ class BeMainTest(TestBase):
             stdout, "be_main_complete", [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14]
         )
 
-    def test_be_benchmark_no_report(self):
+    def test_be_main_no_report(self):
         stdout, _ = self.execute_script(
             "be_main",
             ["-s", self.score, "-m", "STree", "--title", "test"],
@@ -54,7 +56,7 @@ class BeMainTest(TestBase):
             [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14],
         )
 
-    def test_be_benchmark_best_params(self):
+    def test_be_main_best_params(self):
         stdout, _ = self.execute_script(
             "be_main",
             [
@@ -77,7 +79,59 @@ class BeMainTest(TestBase):
             stdout, "be_main_best", [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14]
         )
 
-    def test_be_benchmark_grid_params(self):
+    def test_be_main_best_params_non_existent(self):
+        model = "GBC"
+        stdout, stderr = self.execute_script(
+            "be_main",
+            [
+                "-s",
+                self.score,
+                "-m",
+                model,
+                "--title",
+                "test",
+                "-f",
+                "1",
+                "-r",
+                "1",
+            ],
+        )
+        self.assertEqual(stderr.getvalue(), "")
+        file_name = os.path.join(
+            Folders.results, Files.best_results(self.score, model)
+        )
+        self.assertEqual(
+            stdout.getvalue(),
+            f"{file_name} does not exist\n",
+        )
+
+    def test_be_main_grid_non_existent(self):
+        model = "GBC"
+        stdout, stderr = self.execute_script(
+            "be_main",
+            [
+                "-s",
+                self.score,
+                "-m",
+                model,
+                "--title",
+                "test",
+                "-g",
+                "1",
+                "-r",
+                "1",
+            ],
+        )
+        self.assertEqual(stderr.getvalue(), "")
+        file_name = os.path.join(
+            Folders.results, Files.grid_output(self.score, model)
+        )
+        self.assertEqual(
+            stdout.getvalue(),
+            f"{file_name} does not exist\n",
+        )
+
+    def test_be_main_grid_params(self):
         stdout, _ = self.execute_script(
             "be_main",
             [
@@ -100,7 +154,7 @@ class BeMainTest(TestBase):
             stdout, "be_main_grid", [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14]
         )
 
-    def test_be_benchmark_no_data(self):
+    def test_be_main_no_data(self):
         stdout, _ = self.execute_script(
             "be_main", ["-m", "STree", "-d", "unknown", "--title", "test"]
         )
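
The two new tests assert that the message printed by be_main names the missing results file. A hedged sketch of the underlying technique, capturing stdout with unittest.mock.patch and StringIO; fake_cli below stands in for the project's execute_script helper and is not part of the repository:

import unittest
from io import StringIO
from unittest.mock import patch


def fake_cli(path):
    """Stand-in entry point that reports a missing file on stdout."""
    print(f"{path} does not exist")


class MissingFileMessageTest(unittest.TestCase):
    def test_missing_file_message(self):
        # Redirect sys.stdout to a StringIO buffer for the duration
        # of the call, then compare the captured text.
        with patch("sys.stdout", new_callable=StringIO) as stdout:
            fake_cli("results/some_missing_file.json")
        self.assertEqual(
            stdout.getvalue(),
            "results/some_missing_file.json does not exist\n",
        )


if __name__ == "__main__":
    unittest.main()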