Begin be_main tests

This commit is contained in:
2022-05-08 19:59:53 +02:00
parent e58901a307
commit 80eb9f1db7
8 changed files with 93 additions and 12 deletions

View File

@@ -92,7 +92,10 @@ class Datasets:
self.data_sets = [dataset_name] self.data_sets = [dataset_name]
def load(self, name):
    """Load the dataset called *name* through ``self.dataset``.

    Returns:
        Whatever ``self.dataset.load(name)`` returns.

    Raises:
        ValueError: if the underlying loader raises ``FileNotFoundError``.
            The original error is kept as ``__cause__``.
    """
    try:
        return self.dataset.load(name)
    except FileNotFoundError as err:
        # Chain explicitly so the traceback shows the missing file as the
        # direct cause instead of "during handling ... another exception".
        raise ValueError(f"Unknown dataset: {name}") from err
def __iter__(self) -> Diterator: def __iter__(self) -> Diterator:
return Diterator(self.data_sets) return Diterator(self.data_sets)

View File

@@ -1,4 +1,3 @@
from multiprocessing.sharedctypes import Value
import os import os
from operator import itemgetter from operator import itemgetter
import math import math

View File

@@ -31,7 +31,11 @@ def main(args_test=None):
title=args.title, title=args.title,
folds=args.n_folds, folds=args.n_folds,
) )
try:
job.do_experiment() job.do_experiment()
except ValueError as e:
print(e)
else:
if report: if report:
result_file = job.get_output_file() result_file = job.get_output_file()
report = Report(result_file) report = Report(result_file)

View File

@@ -9,7 +9,6 @@ from ..Results import Benchmark
class BenchmarkTest(TestBase): class BenchmarkTest(TestBase):
def tearDown(self) -> None: def tearDown(self) -> None:
benchmark = Benchmark("accuracy", visualize=False)
files = [] files = []
for score in ["accuracy", "unknown"]: for score in ["accuracy", "unknown"]:
files.append(Files.exreport(score)) files.append(Files.exreport(score))

View File

@@ -45,6 +45,18 @@ class DatasetTest(TestBase):
self.assertSequenceEqual(computed, value) self.assertSequenceEqual(computed, value)
self.set_env(".env.dist") self.set_env(".env.dist")
def test_load_dataset(self):
    # A plain comment is used instead of a docstring so the test runner
    # keeps reporting the method name.
    # balance-scale must load as 625 samples with 4 features each.
    features, labels = Datasets().load("balance-scale")
    self.assertSequenceEqual(features.shape, (625, 4))
    self.assertSequenceEqual(labels.shape, (625,))
def test_load_unknown_dataset(self):
    # Loading a dataset that does not exist must fail with a ValueError
    # whose message names the requested dataset.
    datasets = Datasets()
    with self.assertRaises(ValueError) as raised:
        datasets.load("unknown")
    # Checked after the with-block: inside it this line would never run.
    self.assertEqual(str(raised.exception), "Unknown dataset: unknown")
def test_Datasets_subset(self): def test_Datasets_subset(self):
test = { test = {
".env.dist": "balloons", ".env.dist": "balloons",

View File

@@ -18,6 +18,7 @@ from .scripts.Be_Summary_test import BeSummaryTest
from .scripts.Be_Grid_test import BeGridTest from .scripts.Be_Grid_test import BeGridTest
from .scripts.Be_Best_test import BeBestTest from .scripts.Be_Best_test import BeBestTest
from .scripts.Be_Benchmark_test import BeBenchmarkTest from .scripts.Be_Benchmark_test import BeBenchmarkTest
from .scripts.Be_Main_test import BeMainTest
all = [ all = [
"UtilTest", "UtilTest",
@@ -40,4 +41,5 @@ all = [
"BeGridTest", "BeGridTest",
"BeBestTest", "BeBestTest",
"BeBenchmarkTest", "BeBenchmarkTest",
"BeMainTest",
] ]

View File

@@ -0,0 +1,47 @@
import os
from ...Utils import Folders
from ..TestBase import TestBase
class BeMainTest(TestBase):
    """Tests that run the ``be_main`` script and check what it prints."""

    # Test methods carry comments rather than docstrings so the unittest
    # runner keeps reporting them by method name.

    def setUp(self):
        self.prepare_scripts_env()
        self.score = "accuracy"

    def tearDown(self) -> None:
        # The list of files to remove is currently empty; the call is kept
        # as it was so files can be registered here later.
        files = []
        self.remove_files(files, Folders.exreport)
        return super().tearDown()

    def test_be_benchmark_dataset(self):
        stdout, _ = self.execute_script(
            "be_main",
            [
                "-s",
                self.score,
                "-m",
                "STree",
                "-d",
                "balloons",
                "--title",
                "test",
            ],
        )
        with open(os.path.join(self.test_files, "be_main_dataset.test")) as f:
            expected_lines = f.read().splitlines()
        computed_lines = stdout.getvalue().splitlines()
        # compare only report lines without date, time, duration...
        lines_to_compare = [0, 2, 3, 5, 6, 7, 8, 9, 11, 12, 13]
        # Pairing the two outputs with zip() stops at the shorter one, so an
        # empty or truncated report would pass without comparing anything.
        # Require every compared line to exist on both sides first.
        last_line = max(lines_to_compare)
        self.assertGreater(len(expected_lines), last_line)
        self.assertGreater(len(computed_lines), last_line)
        for n_line in lines_to_compare:
            self.assertEqual(
                computed_lines[n_line], expected_lines[n_line], n_line
            )

    def test_be_benchmark_no_data(self):
        # An unknown dataset must be reported on stdout, not raised.
        stdout, _ = self.execute_script(
            "be_main", ["-m", "STree", "-d", "unknown", "--title", "test"]
        )
        self.assertEqual(stdout.getvalue(), "Unknown dataset: unknown\n")

View File

@@ -0,0 +1,15 @@
***********************************************************************************************************************
* Report STree ver. 1.2.4 with 5 Folds cross validation and 10 random seeds. 2022-05-08 19:38:28 *
* test *
* Random seeds: [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] Stratified: False *
* Execution took 0.06 seconds, 0.00 hours, on iMac27 *
* Score is accuracy *
***********************************************************************************************************************
Dataset Samp Feat. Cls Nodes Leaves Depth Score Time Hyperparameters
============================== ===== ===== === ======= ======= ======= =============== ================ ===============
balloons 16 4 2 4.64 2.82 2.66 0.663333±0.3009 0.000671±0.0001 {}
***********************************************************************************************************************
* Accuracy compared to stree_default (liblinear-ovr) .: 0.0165 *
***********************************************************************************************************************
Partial result file removed: results/results_accuracy_STree_iMac27_2022-05-08_19:38:28_0.json