refactor remove iwss from results

This commit is contained in:
2022-05-08 22:50:09 +02:00
parent 4a5225d3dc
commit 09b2ede836
5 changed files with 53 additions and 18 deletions

View File

@@ -8,7 +8,7 @@ class BestResultTest(TestBase):
expected = { expected = {
"balance-scale": [ "balance-scale": [
0.98, 0.98,
{"splitter": "iwss", "max_features": "auto"}, {"splitter": "best", "max_features": "auto"},
"results_accuracy_STree_iMac27_2021-10-27_09:40:40_0.json", "results_accuracy_STree_iMac27_2021-10-27_09:40:40_0.json",
], ],
"balloons": [ "balloons": [

View File

@@ -36,7 +36,7 @@ class ExperimentTest(TestBase):
expected = { expected = {
"balance-scale": [ "balance-scale": [
0.98, 0.98,
{"splitter": "iwss", "max_features": "auto"}, {"splitter": "best", "max_features": "auto"},
"results_accuracy_STree_iMac27_2021-10-27_09:40:40_0.json", "results_accuracy_STree_iMac27_2021-10-27_09:40:40_0.json",
], ],
"balloons": [ "balloons": [

View File

@@ -1 +1 @@
{"balance-scale": [0.98, {"splitter": "iwss", "max_features": "auto"}, "results_accuracy_STree_iMac27_2021-10-27_09:40:40_0.json"], "balloons": [0.86, {"C": 7, "gamma": 0.1, "kernel": "rbf", "max_iter": 10000.0, "multiclass_strategy": "ovr"}, "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json"]} {"balance-scale": [0.98, {"splitter": "best", "max_features": "auto"}, "results_accuracy_STree_iMac27_2021-10-27_09:40:40_0.json"], "balloons": [0.86, {"C": 7, "gamma": 0.1, "kernel": "rbf", "max_iter": 10000.0, "multiclass_strategy": "ovr"}, "results_accuracy_STree_iMac27_2021-09-30_11:42:07_0.json"]}

View File

@@ -15,7 +15,7 @@
"features": 4, "features": 4,
"classes": 3, "classes": 3,
"hyperparameters": { "hyperparameters": {
"splitter": "iwss", "splitter": "best",
"max_features": "auto" "max_features": "auto"
}, },
"nodes": 11.08, "nodes": 11.08,
@@ -32,7 +32,7 @@
"features": 4, "features": 4,
"classes": 2, "classes": 2,
"hyperparameters": { "hyperparameters": {
"splitter": "iwss", "splitter": "best",
"max_features": "auto" "max_features": "auto"
}, },
"nodes": 4.12, "nodes": 4.12,

View File

@@ -1,5 +1,8 @@
import os import os
from io import StringIO
from unittest.mock import patch
from ...Utils import Folders from ...Utils import Folders
from ...Results import Report
from ..TestBase import TestBase from ..TestBase import TestBase
@@ -7,26 +10,16 @@ class BeMainTest(TestBase):
def setUp(self): def setUp(self):
self.prepare_scripts_env() self.prepare_scripts_env()
self.score = "accuracy" self.score = "accuracy"
self.files = []
def tearDown(self) -> None: def tearDown(self) -> None:
files = [] self.remove_files(self.files, ".")
self.remove_files(files, Folders.exreport)
return super().tearDown() return super().tearDown()
def test_be_benchmark_dataset(self): def test_be_benchmark_dataset(self):
stdout, _ = self.execute_script( stdout, _ = self.execute_script(
"be_main", "be_main",
[ ["-m", "STree", "-d", "balloons", "--title", "test"],
"-s",
self.score,
"-m",
"STree",
"-d",
"balloons",
"--title",
"test",
],
) )
with open(os.path.join(self.test_files, "be_main_dataset.test")) as f: with open(os.path.join(self.test_files, "be_main_dataset.test")) as f:
expected = f.read() expected = f.read()
@@ -40,6 +33,48 @@ class BeMainTest(TestBase):
self.assertEqual(computed, expected, n_line) self.assertEqual(computed, expected, n_line)
n_line += 1 n_line += 1
def test_be_benchmark_complete(self):
stdout, _ = self.execute_script(
"be_main",
["-s", self.score, "-m", "STree", "--title", "test", "-r", "1"],
)
with open(os.path.join(self.test_files, "be_main_complete.test")) as f:
expected = f.read()
n_line = 0
# keep the report name to delete it after
self.files.append(stdout.getvalue().splitlines()[-1].split("in ")[1])
# compare only report lines without date, time, duration...
lines_to_compare = [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14]
for expected, computed in zip(
expected.splitlines(), stdout.getvalue().splitlines()
):
if n_line in lines_to_compare:
self.assertEqual(computed, expected, n_line)
n_line += 1
def test_be_benchmark_no_report(self):
stdout, _ = self.execute_script(
"be_main",
["-s", self.score, "-m", "STree", "--title", "test"],
)
with open(os.path.join(self.test_files, "be_main_complete.test")) as f:
expected = f.read()
# keep the report name to delete it after
report_name = stdout.getvalue().splitlines()[-1].split("in ")[1]
self.files.append(report_name)
report = Report(file_name=report_name)
with patch(self.output, new=StringIO()) as stdout:
report.report()
# compare only report lines without date, time, duration...
lines_to_compare = [0, 2, 3, 5, 6, 7, 8, 9, 12, 13, 14]
n_line = 0
for expected, computed in zip(
expected.splitlines(), stdout.getvalue().splitlines()
):
if n_line in lines_to_compare:
self.assertEqual(computed, expected, n_line)
n_line += 1
def test_be_benchmark_no_data(self): def test_be_benchmark_no_data(self):
stdout, _ = self.execute_script( stdout, _ = self.execute_script(
"be_main", ["-m", "STree", "-d", "unknown", "--title", "test"] "be_main", ["-m", "STree", "-d", "unknown", "--title", "test"]