Complete tests of Experiment class

This commit is contained in:
2022-04-23 19:26:57 +02:00
parent b9a3faa604
commit 5848484d53
3 changed files with 53 additions and 3 deletions

View File

@@ -336,8 +336,6 @@ class Experiment:
self._output_results()
self.duration = time.time() - now
self._output_results()
if self.progress_bar:
print(f"Results in {self.output_file}")
class GridSearch:

View File

@@ -153,6 +153,9 @@ if __name__ == "__main__":
result_file = job.get_output_file()
report = Report(result_file)
report.report()
if dataset is not None:
print(f"Partial result file removed: {result_file}")
os.remove(result_file)
else:
print(f"Results in {job.get_output_file()}")

View File

@@ -1,4 +1,5 @@
import os
import json
import unittest
from ..Models import Models
from ..Experiments import Experiment, Datasets
@@ -22,11 +23,13 @@ class ExperimentTest(unittest.TestCase):
"platform": "test",
"title": "Test",
"progress_bar": False,
"folds": 1,
"folds": 2,
}
return Experiment(**params)
def tearDown(self) -> None:
if os.path.exists(self.exp.get_output_file()):
os.remove(self.exp.get_output_file())
return super().tearDown()
def test_build_hyperparams_and_grid_file(self):
@@ -59,3 +62,49 @@ class ExperimentTest(unittest.TestCase):
file_name.startswith("results/results_accuracy_STree_test_")
)
self.assertTrue(file_name.endswith("_0.json"))
def test_exception_n_fold_crossval(self):
self.exp.do_experiment()
with self.assertRaises(ValueError):
self.exp._n_fold_crossval([], [], {})
def test_do_experiment(self):
self.exp.do_experiment()
file_name = self.exp.get_output_file()
with open(file_name) as f:
data = json.load(f)
# Check Header
expected = {
"score_name": "accuracy",
"title": "Test",
"model": "STree",
"stratified": False,
"folds": 2,
"seeds": [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1],
"platform": "test",
}
for key, value in expected.items():
self.assertEqual(data[key], value)
# Check Results
expected_results = [
{
"dataset": "balance-scale",
"samples": 625,
"features": 4,
"classes": 3,
"hyperparameters": {},
},
{
"dataset": "balloons",
"samples": 16,
"features": 4,
"classes": 2,
"hyperparameters": {},
},
]
for expected_result, computed_result in zip(
expected_results, data["results"]
):
for key, value in expected_result.items():
self.assertEqual(computed_result[key], value)