mirror of
https://github.com/Doctorado-ML/benchmark.git
synced 2025-08-15 23:45:54 +00:00
Complete tests of Experiment class
This commit is contained in:
@@ -336,8 +336,6 @@ class Experiment:
|
||||
self._output_results()
|
||||
self.duration = time.time() - now
|
||||
self._output_results()
|
||||
if self.progress_bar:
|
||||
print(f"Results in {self.output_file}")
|
||||
|
||||
|
||||
class GridSearch:
|
||||
|
@@ -153,6 +153,9 @@ if __name__ == "__main__":
|
||||
result_file = job.get_output_file()
|
||||
report = Report(result_file)
|
||||
report.report()
|
||||
|
||||
if dataset is not None:
|
||||
print(f"Partial result file removed: {result_file}")
|
||||
os.remove(result_file)
|
||||
else:
|
||||
print(f"Results in {job.get_output_file()}")
|
||||
|
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import json
|
||||
import unittest
|
||||
from ..Models import Models
|
||||
from ..Experiments import Experiment, Datasets
|
||||
@@ -22,11 +23,13 @@ class ExperimentTest(unittest.TestCase):
|
||||
"platform": "test",
|
||||
"title": "Test",
|
||||
"progress_bar": False,
|
||||
"folds": 1,
|
||||
"folds": 2,
|
||||
}
|
||||
return Experiment(**params)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
if os.path.exists(self.exp.get_output_file()):
|
||||
os.remove(self.exp.get_output_file())
|
||||
return super().tearDown()
|
||||
|
||||
def test_build_hyperparams_and_grid_file(self):
|
||||
@@ -59,3 +62,49 @@ class ExperimentTest(unittest.TestCase):
|
||||
file_name.startswith("results/results_accuracy_STree_test_")
|
||||
)
|
||||
self.assertTrue(file_name.endswith("_0.json"))
|
||||
|
||||
def test_exception_n_fold_crossval(self):
|
||||
self.exp.do_experiment()
|
||||
with self.assertRaises(ValueError):
|
||||
self.exp._n_fold_crossval([], [], {})
|
||||
|
||||
def test_do_experiment(self):
|
||||
self.exp.do_experiment()
|
||||
file_name = self.exp.get_output_file()
|
||||
with open(file_name) as f:
|
||||
data = json.load(f)
|
||||
# Check Header
|
||||
expected = {
|
||||
"score_name": "accuracy",
|
||||
"title": "Test",
|
||||
"model": "STree",
|
||||
"stratified": False,
|
||||
"folds": 2,
|
||||
"seeds": [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1],
|
||||
"platform": "test",
|
||||
}
|
||||
for key, value in expected.items():
|
||||
self.assertEqual(data[key], value)
|
||||
# Check Results
|
||||
expected_results = [
|
||||
{
|
||||
"dataset": "balance-scale",
|
||||
"samples": 625,
|
||||
"features": 4,
|
||||
"classes": 3,
|
||||
"hyperparameters": {},
|
||||
},
|
||||
{
|
||||
"dataset": "balloons",
|
||||
"samples": 16,
|
||||
"features": 4,
|
||||
"classes": 2,
|
||||
"hyperparameters": {},
|
||||
},
|
||||
]
|
||||
|
||||
for expected_result, computed_result in zip(
|
||||
expected_results, data["results"]
|
||||
):
|
||||
for key, value in expected_result.items():
|
||||
self.assertEqual(computed_result[key], value)
|
||||
|
Reference in New Issue
Block a user