Fix tests

This commit is contained in:
2023-01-06 14:29:52 +01:00
parent 9ba6c55d49
commit d854d9ddf1
10 changed files with 82 additions and 34 deletions

View File

@@ -28,6 +28,12 @@ class DatasetsArff:
def folder(): def folder():
return "datasets" return "datasets"
@staticmethod
def get_range_features(X, c_features):
    """Return the feature indices described by a continuous-features spec.

    Args:
        X: Data matrix; only ``X.shape[1]`` (the number of columns) is read,
            and only when every feature is requested.
        c_features: Either the string ``"all"`` (surrounding whitespace is
            ignored) or a JSON-encoded value such as ``"[0, 1]"``.

    Returns:
        ``list(range(X.shape[1]))`` when ``c_features`` is ``"all"``;
        otherwise whatever ``json.loads`` decodes from ``c_features``.

    Raises:
        json.JSONDecodeError: If ``c_features`` is neither ``"all"`` nor
            valid JSON.
    """
    if c_features.strip() == "all":
        # "all" is shorthand for every column index of X.
        return list(range(X.shape[1]))
    return json.loads(c_features)
def load(self, name, class_name): def load(self, name, class_name):
file_name = os.path.join(self.folder(), self.dataset_names(name)) file_name = os.path.join(self.folder(), self.dataset_names(name))
data = arff.loadarff(file_name) data = arff.loadarff(file_name)
@@ -51,6 +57,10 @@ class DatasetsTanveer:
def folder(): def folder():
return "data" return "data"
@staticmethod
def get_range_features(X, name):
    """Return the list of continuous feature indices, which is always empty.

    Both arguments are ignored; they exist so this loader can be called the
    same way as the other dataset loaders. Note that ``Datasets.load`` passes
    the dataset's continuous-features entry as the second argument, even
    though the parameter is called ``name`` here.

    Args:
        X: Data matrix (unused).
        name: Value supplied by the caller (unused).

    Returns:
        An empty list.
    """
    return []
def load(self, name, *args): def load(self, name, *args):
file_name = os.path.join(self.folder(), self.dataset_names(name)) file_name = os.path.join(self.folder(), self.dataset_names(name))
data = pd.read_csv( data = pd.read_csv(
@@ -76,6 +86,10 @@ class DatasetsSurcov:
def folder(): def folder():
return "datasets" return "datasets"
@staticmethod
def get_range_features(X, name):
    """Return the list of continuous feature indices, which is always empty.

    Both arguments are ignored; they exist so this loader can be called the
    same way as the other dataset loaders. Note that ``Datasets.load`` passes
    the dataset's continuous-features entry as the second argument, even
    though the parameter is called ``name`` here.

    Args:
        X: Data matrix (unused).
        name: Value supplied by the caller (unused).

    Returns:
        An empty list.
    """
    return []
def load(self, name, *args): def load(self, name, *args):
file_name = os.path.join(self.folder(), self.dataset_names(name)) file_name = os.path.join(self.folder(), self.dataset_names(name))
data = pd.read_csv( data = pd.read_csv(
@@ -179,16 +193,13 @@ class Datasets:
} }
def load(self, name, dataframe=False): def load(self, name, dataframe=False):
def get_range_features(X, name):
c_features = self.continuous_features[name]
if c_features.strip() == "all":
return list(range(X.shape[1]))
return json.loads(c_features)
try: try:
class_name = self.class_names[self.data_sets.index(name)] class_name = self.class_names[self.data_sets.index(name)]
X, y = self.dataset.load(name, class_name) X, y = self.dataset.load(name, class_name)
self.continuous_features_dataset = get_range_features(X, name) self.continuous_features_dataset = self.dataset.get_range_features(
X, self.continuous_features[name]
)
if self.discretize: if self.discretize:
X = self.discretize_dataset(X, y) X = self.discretize_dataset(X, y)
self.build_states(name, X) self.build_states(name, X)

View File

@@ -7,5 +7,4 @@ stratified=0
source_data=Tanveer source_data=Tanveer
seeds=[57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] seeds=[57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]
discretize=0 discretize=0
ignore_nan=0 ignore_nan=0

View File

@@ -7,4 +7,4 @@ stratified=0
source_data=Tanveer source_data=Tanveer
seeds=[57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] seeds=[57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]
discretize=0 discretize=0
ignore_nan=0 ignore_nan=0

View File

@@ -1,4 +1,3 @@
import shutil
from .TestBase import TestBase from .TestBase import TestBase
from ..Experiments import Randomized from ..Experiments import Randomized
from ..Datasets import Datasets from ..Datasets import Datasets
@@ -17,10 +16,6 @@ class DatasetTest(TestBase):
self.set_env(".env.dist") self.set_env(".env.dist")
return super().tearDown() return super().tearDown()
@staticmethod
def set_env(env):
    """Copy the file at ``env`` to ``.env``.

    The destination is the relative path ``.env``, so it is resolved
    against the current working directory.

    Args:
        env: Path of the environment file to copy.
    """
    shutil.copy(env, ".env")
def test_Randomized(self): def test_Randomized(self):
expected = [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] expected = [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]
self.assertSequenceEqual(Randomized.seeds(), expected) self.assertSequenceEqual(Randomized.seeds(), expected)

View File

@@ -8,10 +8,10 @@ class ExperimentTest(TestBase):
def setUp(self): def setUp(self):
self.exp = self.build_exp() self.exp = self.build_exp()
def build_exp(self, hyperparams=False, grid=False): def build_exp(self, hyperparams=False, grid=False, model="STree"):
params = { params = {
"score_name": "accuracy", "score_name": "accuracy",
"model_name": "STree", "model_name": model,
"stratified": "0", "stratified": "0",
"datasets": Datasets(), "datasets": Datasets(),
"hyperparams_dict": "{}", "hyperparams_dict": "{}",
@@ -21,6 +21,7 @@ class ExperimentTest(TestBase):
"title": "Test", "title": "Test",
"progress_bar": False, "progress_bar": False,
"folds": 2, "folds": 2,
"ignore_nan": False,
} }
return Experiment(**params) return Experiment(**params)
@@ -31,6 +32,7 @@ class ExperimentTest(TestBase):
], ],
".", ".",
) )
self.set_env(".env.dist")
return super().tearDown() return super().tearDown()
def test_build_hyperparams_file(self): def test_build_hyperparams_file(self):
@@ -89,7 +91,7 @@ class ExperimentTest(TestBase):
def test_exception_n_fold_crossval(self): def test_exception_n_fold_crossval(self):
self.exp.do_experiment() self.exp.do_experiment()
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
self.exp._n_fold_crossval([], [], {}) self.exp._n_fold_crossval("", [], [], {})
def test_do_experiment(self): def test_do_experiment(self):
self.exp.do_experiment() self.exp.do_experiment()
@@ -131,3 +133,27 @@ class ExperimentTest(TestBase):
): ):
for key, value in expected_result.items(): for key, value in expected_result.items():
self.assertEqual(computed_result[key], value) self.assertEqual(computed_result[key], value)
def test_build_fit_parameters(self):
    """Check the fit parameters built for the iris dataset with a TAN model.

    Switches the environment to ``.env.arff``, builds a TAN experiment,
    loads iris, and verifies the ``state_names`` and ``features`` entries
    returned by ``_build_fit_params("iris")``.
    """
    self.set_env(".env.arff")
    expected = {
        "state_names": {
            "sepallength": [0, 1, 2],
            "sepalwidth": [0, 1, 3, 4],
            "petallength": [0, 1, 2, 3],
            "petalwidth": [0, 1, 2, 3],
        },
        "features": [
            "sepallength",
            "sepalwidth",
            "petallength",
            "petalwidth",
        ],
    }
    exp = self.build_exp(model="TAN")
    # The returned data is not needed; the dataset is loaded before the fit
    # parameters are built (presumably so the dataset state is populated;
    # confirm against Datasets.load).
    exp.datasets.load("iris")
    computed = exp._build_fit_params("iris")
    for key, value in expected["state_names"].items():
        self.assertEqual(computed["state_names"][key], value)
    for feature in expected["features"]:
        self.assertIn(feature, computed["features"])

View File

@@ -4,6 +4,7 @@ import pathlib
import sys import sys
import csv import csv
import unittest import unittest
import shutil
from importlib import import_module from importlib import import_module
from io import StringIO from io import StringIO
from unittest.mock import patch from unittest.mock import patch
@@ -19,6 +20,10 @@ class TestBase(unittest.TestCase):
self.stree_version = "1.2.4" self.stree_version = "1.2.4"
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@staticmethod
def set_env(env):
    """Copy the file at ``env`` to ``.env``.

    The destination is the relative path ``.env``, so it is resolved
    against the current working directory.

    Args:
        env: Path of the environment file to copy.
    """
    shutil.copy(env, ".env")
def remove_files(self, files, folder): def remove_files(self, files, folder):
for file_name in files: for file_name in files:
file_name = os.path.join(folder, file_name) file_name = os.path.join(folder, file_name)

View File

@@ -180,6 +180,7 @@ class UtilTest(TestBase):
"source_data": "Tanveer", "source_data": "Tanveer",
"seeds": "[57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]", "seeds": "[57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]",
"discretize": "0", "discretize": "0",
"ignore_nan": "0",
} }
computed = EnvData().load() computed = EnvData().load()
self.assertDictEqual(computed, expected) self.assertDictEqual(computed, expected)
@@ -191,8 +192,16 @@ class UtilTest(TestBase):
"n_folds": 5, "n_folds": 5,
"model": "STree", "model": "STree",
"stratified": "0", "stratified": "0",
"ignore_nan": "0",
} }
ap = argparse.ArgumentParser() ap = argparse.ArgumentParser()
ap.add_argument(
"--ignore-nan",
action=EnvDefault,
envvar="ignore_nan",
required=True,
help="Ignore nan results",
)
ap.add_argument( ap.add_argument(
"-s", "-s",
"--score", "--score",

View File

@@ -1,2 +1,2 @@
iris,class iris,class,all
wine,class wine,class,[0, 1]

View File

@@ -1,25 +1,28 @@
1;1;"Datasets used in benchmark ver. 0.2.0" 1;1;"Datasets used in benchmark ver. 0.4.0"
2;1;" Default score accuracy" 2;1;" Default score accuracy"
2;2;"Cross validation" 2;2;"Cross validation"
2;5;"5 Folds" 2;6;"5 Folds"
3;2;"Stratified" 3;2;"Stratified"
3;5;"False" 3;6;"False"
4;2;"Discretized" 4;2;"Discretized"
4;5;"False" 4;6;"False"
5;2;"Seeds" 5;2;"Seeds"
5;5;"[57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]" 5;6;"[57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]"
6;1;"Dataset" 6;1;"Dataset"
6;2;"Samples" 6;2;"Samples"
6;3;"Features" 6;3;"Features"
6;4;"Classes" 6;4;"Continuous"
6;5;"Balance" 6;5;"Classes"
6;6;"Balance"
7;1;"balance-scale" 7;1;"balance-scale"
7;2;"625" 7;2;"625"
7;3;"4" 7;3;"4"
7;4;"3" 7;4;"0"
7;5;" 7.84%/ 46.08%/ 46.08%" 7;5;"3"
7;6;" 7.84%/ 46.08%/ 46.08%"
8;1;"balloons" 8;1;"balloons"
8;2;"16" 8;2;"16"
8;3;"4" 8;3;"4"
8;4;"2" 8;4;"0"
8;5;"56.25%/ 43.75%" 8;5;"2"
8;6;"56.25%/ 43.75%"

View File

@@ -1,6 +1,6 @@
Datasets used in benchmark ver. 0.2.0 Datasets used in benchmark ver. 0.2.0
Dataset Sampl. Feat. Cls Balance Dataset Sampl. Feat. Cont Cls Balance
============================== ====== ===== === ============================================================ ============================== ====== ===== ==== === ============================================================
balance-scale 625 4 3 7.84%/ 46.08%/ 46.08% balance-scale 625 4 0 3 7.84%/ 46.08%/ 46.08%
balloons 16 4 2 56.25%/ 43.75% balloons 16 4 0 2 56.25%/ 43.75%