Enhance error msgs in be_main

This commit is contained in:
2022-05-09 11:37:53 +02:00
parent ca96d05124
commit 7501ce7761
3 changed files with 78 additions and 21 deletions

View File

@@ -222,8 +222,11 @@ class Experiment:
grid_file = os.path.join(
Folders.results, Files.grid_output(score_name, model_name)
)
with open(grid_file) as f:
self.hyperparameters_dict = json.load(f)
try:
with open(grid_file) as f:
self.hyperparameters_dict = json.load(f)
except FileNotFoundError:
raise ValueError(f"{grid_file} does not exist")
else:
self.hyperparameters_dict = hyper.fill(
dictionary=dictionary,

View File

@@ -18,20 +18,20 @@ def main(args_test=None):
report = args.report or args.dataset is not None
if args.grid_paramfile:
args.paramfile = False
job = Experiment(
score_name=args.score,
model_name=args.model,
stratified=args.stratified,
datasets=Datasets(dataset_name=args.dataset),
hyperparams_dict=args.hyperparameters,
hyperparams_file=args.paramfile,
grid_paramfile=args.grid_paramfile,
progress_bar=not args.quiet,
platform=args.platform,
title=args.title,
folds=args.n_folds,
)
try:
job = Experiment(
score_name=args.score,
model_name=args.model,
stratified=args.stratified,
datasets=Datasets(dataset_name=args.dataset),
hyperparams_dict=args.hyperparameters,
hyperparams_file=args.paramfile,
grid_paramfile=args.grid_paramfile,
progress_bar=not args.quiet,
platform=args.platform,
title=args.title,
folds=args.n_folds,
)
job.do_experiment()
except ValueError as e:
print(e)

View File

@@ -1,6 +1,8 @@
import os
from io import StringIO
from unittest.mock import patch
from ...Results import Report
from ...Utils import Files, Folders
from ..TestBase import TestBase
@@ -14,7 +16,7 @@ class BeMainTest(TestBase):
self.remove_files(self.files, ".")
return super().tearDown()
def test_be_benchmark_dataset(self):
def test_be_main_dataset(self):
stdout, _ = self.execute_script(
"be_main",
["-m", "STree", "-d", "balloons", "--title", "test"],
@@ -25,7 +27,7 @@ class BeMainTest(TestBase):
lines_to_compare=[0, 2, 3, 5, 6, 7, 8, 9, 11, 12, 13],
)
def test_be_benchmark_complete(self):
def test_be_main_complete(self):
stdout, _ = self.execute_script(
"be_main",
["-s", self.score, "-m", "STree", "--title", "test", "-r", "1"],
@@ -37,7 +39,7 @@ class BeMainTest(TestBase):
stdout, "be_main_complete", [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14]
)
def test_be_benchmark_no_report(self):
def test_be_main_no_report(self):
stdout, _ = self.execute_script(
"be_main",
["-s", self.score, "-m", "STree", "--title", "test"],
@@ -54,7 +56,7 @@ class BeMainTest(TestBase):
[0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14],
)
def test_be_benchmark_best_params(self):
def test_be_main_best_params(self):
stdout, _ = self.execute_script(
"be_main",
[
@@ -77,7 +79,59 @@ class BeMainTest(TestBase):
stdout, "be_main_best", [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14]
)
def test_be_benchmark_grid_params(self):
def test_be_main_best_params_non_existent(self):
model = "GBC"
stdout, stderr = self.execute_script(
"be_main",
[
"-s",
self.score,
"-m",
model,
"--title",
"test",
"-f",
"1",
"-r",
"1",
],
)
self.assertEqual(stderr.getvalue(), "")
file_name = os.path.join(
Folders.results, Files.best_results(self.score, model)
)
self.assertEqual(
stdout.getvalue(),
f"{file_name} does not exist\n",
)
def test_be_main_grid_non_existent(self):
model = "GBC"
stdout, stderr = self.execute_script(
"be_main",
[
"-s",
self.score,
"-m",
model,
"--title",
"test",
"-g",
"1",
"-r",
"1",
],
)
self.assertEqual(stderr.getvalue(), "")
file_name = os.path.join(
Folders.results, Files.grid_output(self.score, model)
)
self.assertEqual(
stdout.getvalue(),
f"{file_name} does not exist\n",
)
def test_be_main_grid_params(self):
stdout, _ = self.execute_script(
"be_main",
[
@@ -100,7 +154,7 @@ class BeMainTest(TestBase):
stdout, "be_main_grid", [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14]
)
def test_be_benchmark_no_data(self):
def test_be_main_no_data(self):
stdout, _ = self.execute_script(
"be_main", ["-m", "STree", "-d", "unknown", "--title", "test"]
)