mirror of
https://github.com/Doctorado-ML/benchmark.git
synced 2025-08-16 16:05:54 +00:00
Begin be_main tests
This commit is contained in:
@@ -92,7 +92,10 @@ class Datasets:
|
||||
self.data_sets = [dataset_name]
|
||||
|
||||
def load(self, name):
|
||||
return self.dataset.load(name)
|
||||
try:
|
||||
return self.dataset.load(name)
|
||||
except FileNotFoundError:
|
||||
raise ValueError(f"Unknown dataset: {name}")
|
||||
|
||||
def __iter__(self) -> Diterator:
|
||||
return Diterator(self.data_sets)
|
||||
|
@@ -1,4 +1,3 @@
|
||||
from multiprocessing.sharedctypes import Value
|
||||
import os
|
||||
from operator import itemgetter
|
||||
import math
|
||||
|
@@ -31,13 +31,17 @@ def main(args_test=None):
|
||||
title=args.title,
|
||||
folds=args.n_folds,
|
||||
)
|
||||
job.do_experiment()
|
||||
if report:
|
||||
result_file = job.get_output_file()
|
||||
report = Report(result_file)
|
||||
report.report()
|
||||
if args.dataset is not None:
|
||||
print(f"Partial result file removed: {result_file}")
|
||||
os.remove(result_file)
|
||||
try:
|
||||
job.do_experiment()
|
||||
except ValueError as e:
|
||||
print(e)
|
||||
else:
|
||||
print(f"Results in {job.get_output_file()}")
|
||||
if report:
|
||||
result_file = job.get_output_file()
|
||||
report = Report(result_file)
|
||||
report.report()
|
||||
if args.dataset is not None:
|
||||
print(f"Partial result file removed: {result_file}")
|
||||
os.remove(result_file)
|
||||
else:
|
||||
print(f"Results in {job.get_output_file()}")
|
||||
|
@@ -9,7 +9,6 @@ from ..Results import Benchmark
|
||||
|
||||
class BenchmarkTest(TestBase):
|
||||
def tearDown(self) -> None:
|
||||
benchmark = Benchmark("accuracy", visualize=False)
|
||||
files = []
|
||||
for score in ["accuracy", "unknown"]:
|
||||
files.append(Files.exreport(score))
|
||||
|
@@ -45,6 +45,18 @@ class DatasetTest(TestBase):
|
||||
self.assertSequenceEqual(computed, value)
|
||||
self.set_env(".env.dist")
|
||||
|
||||
def test_load_dataset(self):
|
||||
dt = Datasets()
|
||||
X, y = dt.load("balance-scale")
|
||||
self.assertSequenceEqual(X.shape, (625, 4))
|
||||
self.assertSequenceEqual(y.shape, (625,))
|
||||
|
||||
def test_load_unknown_dataset(self):
|
||||
dt = Datasets()
|
||||
with self.assertRaises(ValueError) as msg:
|
||||
dt.load("unknown")
|
||||
self.assertEqual(str(msg.exception), "Unknown dataset: unknown")
|
||||
|
||||
def test_Datasets_subset(self):
|
||||
test = {
|
||||
".env.dist": "balloons",
|
||||
|
@@ -18,6 +18,7 @@ from .scripts.Be_Summary_test import BeSummaryTest
|
||||
from .scripts.Be_Grid_test import BeGridTest
|
||||
from .scripts.Be_Best_test import BeBestTest
|
||||
from .scripts.Be_Benchmark_test import BeBenchmarkTest
|
||||
from .scripts.Be_Main_test import BeMainTest
|
||||
|
||||
all = [
|
||||
"UtilTest",
|
||||
@@ -40,4 +41,5 @@ all = [
|
||||
"BeGridTest",
|
||||
"BeBestTest",
|
||||
"BeBenchmarkTest",
|
||||
"BeMainTest",
|
||||
]
|
||||
|
47
benchmark/tests/scripts/Be_Main_test.py
Normal file
47
benchmark/tests/scripts/Be_Main_test.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import os
|
||||
from ...Utils import Folders
|
||||
from ..TestBase import TestBase
|
||||
|
||||
|
||||
class BeMainTest(TestBase):
|
||||
def setUp(self):
|
||||
self.prepare_scripts_env()
|
||||
self.score = "accuracy"
|
||||
|
||||
def tearDown(self) -> None:
|
||||
files = []
|
||||
|
||||
self.remove_files(files, Folders.exreport)
|
||||
return super().tearDown()
|
||||
|
||||
def test_be_benchmark_dataset(self):
|
||||
stdout, _ = self.execute_script(
|
||||
"be_main",
|
||||
[
|
||||
"-s",
|
||||
self.score,
|
||||
"-m",
|
||||
"STree",
|
||||
"-d",
|
||||
"balloons",
|
||||
"--title",
|
||||
"test",
|
||||
],
|
||||
)
|
||||
with open(os.path.join(self.test_files, "be_main_dataset.test")) as f:
|
||||
expected = f.read()
|
||||
n_line = 0
|
||||
# compare only report lines without date, time, duration...
|
||||
lines_to_compare = [0, 2, 3, 5, 6, 7, 8, 9, 11, 12, 13]
|
||||
for expected, computed in zip(
|
||||
expected.splitlines(), stdout.getvalue().splitlines()
|
||||
):
|
||||
if n_line in lines_to_compare:
|
||||
self.assertEqual(computed, expected, n_line)
|
||||
n_line += 1
|
||||
|
||||
def test_be_benchmark_no_data(self):
|
||||
stdout, _ = self.execute_script(
|
||||
"be_main", ["-m", "STree", "-d", "unknown", "--title", "test"]
|
||||
)
|
||||
self.assertEqual(stdout.getvalue(), "Unknown dataset: unknown\n")
|
15
benchmark/tests/test_files/be_main_dataset.test
Normal file
15
benchmark/tests/test_files/be_main_dataset.test
Normal file
@@ -0,0 +1,15 @@
|
||||
[94m***********************************************************************************************************************
|
||||
[94m* Report STree ver. 1.2.4 with 5 Folds cross validation and 10 random seeds. 2022-05-08 19:38:28 *
|
||||
[94m* test *
|
||||
[94m* Random seeds: [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] Stratified: False *
|
||||
[94m* Execution took 0.06 seconds, 0.00 hours, on iMac27 *
|
||||
[94m* Score is accuracy *
|
||||
[94m***********************************************************************************************************************
|
||||
|
||||
Dataset Samp Feat. Cls Nodes Leaves Depth Score Time Hyperparameters
|
||||
============================== ===== ===== === ======= ======= ======= =============== ================ ===============
|
||||
[96mballoons 16 4 2 4.64 2.82 2.66 0.663333±0.3009 0.000671±0.0001 {}
|
||||
[94m***********************************************************************************************************************
|
||||
[94m* Accuracy compared to stree_default (liblinear-ovr) .: 0.0165 *
|
||||
[94m***********************************************************************************************************************
|
||||
Partial result file removed: results/results_accuracy_STree_iMac27_2022-05-08_19:38:28_0.json
|
Reference in New Issue
Block a user