mirror of
https://github.com/Doctorado-ML/benchmark.git
synced 2025-08-18 08:55:53 +00:00
Fix tests
This commit is contained in:
@@ -28,6 +28,12 @@ class DatasetsArff:
|
|||||||
def folder():
|
def folder():
|
||||||
return "datasets"
|
return "datasets"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_range_features(X, c_features):
|
||||||
|
if c_features.strip() == "all":
|
||||||
|
return list(range(X.shape[1]))
|
||||||
|
return json.loads(c_features)
|
||||||
|
|
||||||
def load(self, name, class_name):
|
def load(self, name, class_name):
|
||||||
file_name = os.path.join(self.folder(), self.dataset_names(name))
|
file_name = os.path.join(self.folder(), self.dataset_names(name))
|
||||||
data = arff.loadarff(file_name)
|
data = arff.loadarff(file_name)
|
||||||
@@ -51,6 +57,10 @@ class DatasetsTanveer:
|
|||||||
def folder():
|
def folder():
|
||||||
return "data"
|
return "data"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_range_features(X, name):
|
||||||
|
return []
|
||||||
|
|
||||||
def load(self, name, *args):
|
def load(self, name, *args):
|
||||||
file_name = os.path.join(self.folder(), self.dataset_names(name))
|
file_name = os.path.join(self.folder(), self.dataset_names(name))
|
||||||
data = pd.read_csv(
|
data = pd.read_csv(
|
||||||
@@ -76,6 +86,10 @@ class DatasetsSurcov:
|
|||||||
def folder():
|
def folder():
|
||||||
return "datasets"
|
return "datasets"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_range_features(X, name):
|
||||||
|
return []
|
||||||
|
|
||||||
def load(self, name, *args):
|
def load(self, name, *args):
|
||||||
file_name = os.path.join(self.folder(), self.dataset_names(name))
|
file_name = os.path.join(self.folder(), self.dataset_names(name))
|
||||||
data = pd.read_csv(
|
data = pd.read_csv(
|
||||||
@@ -179,16 +193,13 @@ class Datasets:
|
|||||||
}
|
}
|
||||||
|
|
||||||
def load(self, name, dataframe=False):
|
def load(self, name, dataframe=False):
|
||||||
def get_range_features(X, name):
|
|
||||||
c_features = self.continuous_features[name]
|
|
||||||
if c_features.strip() == "all":
|
|
||||||
return list(range(X.shape[1]))
|
|
||||||
return json.loads(c_features)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
class_name = self.class_names[self.data_sets.index(name)]
|
class_name = self.class_names[self.data_sets.index(name)]
|
||||||
X, y = self.dataset.load(name, class_name)
|
X, y = self.dataset.load(name, class_name)
|
||||||
self.continuous_features_dataset = get_range_features(X, name)
|
self.continuous_features_dataset = self.dataset.get_range_features(
|
||||||
|
X, self.continuous_features[name]
|
||||||
|
)
|
||||||
if self.discretize:
|
if self.discretize:
|
||||||
X = self.discretize_dataset(X, y)
|
X = self.discretize_dataset(X, y)
|
||||||
self.build_states(name, X)
|
self.build_states(name, X)
|
||||||
|
@@ -8,4 +8,3 @@ source_data=Tanveer
|
|||||||
seeds=[57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]
|
seeds=[57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]
|
||||||
discretize=0
|
discretize=0
|
||||||
ignore_nan=0
|
ignore_nan=0
|
||||||
|
|
||||||
|
@@ -1,4 +1,3 @@
|
|||||||
import shutil
|
|
||||||
from .TestBase import TestBase
|
from .TestBase import TestBase
|
||||||
from ..Experiments import Randomized
|
from ..Experiments import Randomized
|
||||||
from ..Datasets import Datasets
|
from ..Datasets import Datasets
|
||||||
@@ -17,10 +16,6 @@ class DatasetTest(TestBase):
|
|||||||
self.set_env(".env.dist")
|
self.set_env(".env.dist")
|
||||||
return super().tearDown()
|
return super().tearDown()
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def set_env(env):
|
|
||||||
shutil.copy(env, ".env")
|
|
||||||
|
|
||||||
def test_Randomized(self):
|
def test_Randomized(self):
|
||||||
expected = [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]
|
expected = [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]
|
||||||
self.assertSequenceEqual(Randomized.seeds(), expected)
|
self.assertSequenceEqual(Randomized.seeds(), expected)
|
||||||
|
@@ -8,10 +8,10 @@ class ExperimentTest(TestBase):
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.exp = self.build_exp()
|
self.exp = self.build_exp()
|
||||||
|
|
||||||
def build_exp(self, hyperparams=False, grid=False):
|
def build_exp(self, hyperparams=False, grid=False, model="STree"):
|
||||||
params = {
|
params = {
|
||||||
"score_name": "accuracy",
|
"score_name": "accuracy",
|
||||||
"model_name": "STree",
|
"model_name": model,
|
||||||
"stratified": "0",
|
"stratified": "0",
|
||||||
"datasets": Datasets(),
|
"datasets": Datasets(),
|
||||||
"hyperparams_dict": "{}",
|
"hyperparams_dict": "{}",
|
||||||
@@ -21,6 +21,7 @@ class ExperimentTest(TestBase):
|
|||||||
"title": "Test",
|
"title": "Test",
|
||||||
"progress_bar": False,
|
"progress_bar": False,
|
||||||
"folds": 2,
|
"folds": 2,
|
||||||
|
"ignore_nan": False,
|
||||||
}
|
}
|
||||||
return Experiment(**params)
|
return Experiment(**params)
|
||||||
|
|
||||||
@@ -31,6 +32,7 @@ class ExperimentTest(TestBase):
|
|||||||
],
|
],
|
||||||
".",
|
".",
|
||||||
)
|
)
|
||||||
|
self.set_env(".env.dist")
|
||||||
return super().tearDown()
|
return super().tearDown()
|
||||||
|
|
||||||
def test_build_hyperparams_file(self):
|
def test_build_hyperparams_file(self):
|
||||||
@@ -89,7 +91,7 @@ class ExperimentTest(TestBase):
|
|||||||
def test_exception_n_fold_crossval(self):
|
def test_exception_n_fold_crossval(self):
|
||||||
self.exp.do_experiment()
|
self.exp.do_experiment()
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
self.exp._n_fold_crossval([], [], {})
|
self.exp._n_fold_crossval("", [], [], {})
|
||||||
|
|
||||||
def test_do_experiment(self):
|
def test_do_experiment(self):
|
||||||
self.exp.do_experiment()
|
self.exp.do_experiment()
|
||||||
@@ -131,3 +133,27 @@ class ExperimentTest(TestBase):
|
|||||||
):
|
):
|
||||||
for key, value in expected_result.items():
|
for key, value in expected_result.items():
|
||||||
self.assertEqual(computed_result[key], value)
|
self.assertEqual(computed_result[key], value)
|
||||||
|
|
||||||
|
def test_build_fit_parameters(self):
|
||||||
|
self.set_env(".env.arff")
|
||||||
|
expected = {
|
||||||
|
"state_names": {
|
||||||
|
"sepallength": [0, 1, 2],
|
||||||
|
"sepalwidth": [0, 1, 3, 4],
|
||||||
|
"petallength": [0, 1, 2, 3],
|
||||||
|
"petalwidth": [0, 1, 2, 3],
|
||||||
|
},
|
||||||
|
"features": [
|
||||||
|
"sepallength",
|
||||||
|
"sepalwidth",
|
||||||
|
"petallength",
|
||||||
|
"petalwidth",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
exp = self.build_exp(model="TAN")
|
||||||
|
X, y = exp.datasets.load("iris")
|
||||||
|
computed = exp._build_fit_params("iris")
|
||||||
|
for key, value in expected["state_names"].items():
|
||||||
|
self.assertEqual(computed["state_names"][key], value)
|
||||||
|
for feature in expected["features"]:
|
||||||
|
self.assertIn(feature, computed["features"])
|
||||||
|
@@ -4,6 +4,7 @@ import pathlib
|
|||||||
import sys
|
import sys
|
||||||
import csv
|
import csv
|
||||||
import unittest
|
import unittest
|
||||||
|
import shutil
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
@@ -19,6 +20,10 @@ class TestBase(unittest.TestCase):
|
|||||||
self.stree_version = "1.2.4"
|
self.stree_version = "1.2.4"
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_env(env):
|
||||||
|
shutil.copy(env, ".env")
|
||||||
|
|
||||||
def remove_files(self, files, folder):
|
def remove_files(self, files, folder):
|
||||||
for file_name in files:
|
for file_name in files:
|
||||||
file_name = os.path.join(folder, file_name)
|
file_name = os.path.join(folder, file_name)
|
||||||
|
@@ -180,6 +180,7 @@ class UtilTest(TestBase):
|
|||||||
"source_data": "Tanveer",
|
"source_data": "Tanveer",
|
||||||
"seeds": "[57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]",
|
"seeds": "[57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]",
|
||||||
"discretize": "0",
|
"discretize": "0",
|
||||||
|
"ignore_nan": "0",
|
||||||
}
|
}
|
||||||
computed = EnvData().load()
|
computed = EnvData().load()
|
||||||
self.assertDictEqual(computed, expected)
|
self.assertDictEqual(computed, expected)
|
||||||
@@ -191,8 +192,16 @@ class UtilTest(TestBase):
|
|||||||
"n_folds": 5,
|
"n_folds": 5,
|
||||||
"model": "STree",
|
"model": "STree",
|
||||||
"stratified": "0",
|
"stratified": "0",
|
||||||
|
"ignore_nan": "0",
|
||||||
}
|
}
|
||||||
ap = argparse.ArgumentParser()
|
ap = argparse.ArgumentParser()
|
||||||
|
ap.add_argument(
|
||||||
|
"--ignore-nan",
|
||||||
|
action=EnvDefault,
|
||||||
|
envvar="ignore_nan",
|
||||||
|
required=True,
|
||||||
|
help="Ignore nan results",
|
||||||
|
)
|
||||||
ap.add_argument(
|
ap.add_argument(
|
||||||
"-s",
|
"-s",
|
||||||
"--score",
|
"--score",
|
||||||
|
@@ -1,2 +1,2 @@
|
|||||||
iris,class
|
iris,class,all
|
||||||
wine,class
|
wine,class,[0, 1]
|
||||||
|
@@ -1,25 +1,28 @@
|
|||||||
1;1;"Datasets used in benchmark ver. 0.2.0"
|
1;1;"Datasets used in benchmark ver. 0.4.0"
|
||||||
2;1;" Default score accuracy"
|
2;1;" Default score accuracy"
|
||||||
2;2;"Cross validation"
|
2;2;"Cross validation"
|
||||||
2;5;"5 Folds"
|
2;6;"5 Folds"
|
||||||
3;2;"Stratified"
|
3;2;"Stratified"
|
||||||
3;5;"False"
|
3;6;"False"
|
||||||
4;2;"Discretized"
|
4;2;"Discretized"
|
||||||
4;5;"False"
|
4;6;"False"
|
||||||
5;2;"Seeds"
|
5;2;"Seeds"
|
||||||
5;5;"[57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]"
|
5;6;"[57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]"
|
||||||
6;1;"Dataset"
|
6;1;"Dataset"
|
||||||
6;2;"Samples"
|
6;2;"Samples"
|
||||||
6;3;"Features"
|
6;3;"Features"
|
||||||
6;4;"Classes"
|
6;4;"Continuous"
|
||||||
6;5;"Balance"
|
6;5;"Classes"
|
||||||
|
6;6;"Balance"
|
||||||
7;1;"balance-scale"
|
7;1;"balance-scale"
|
||||||
7;2;"625"
|
7;2;"625"
|
||||||
7;3;"4"
|
7;3;"4"
|
||||||
7;4;"3"
|
7;4;"0"
|
||||||
7;5;" 7.84%/ 46.08%/ 46.08%"
|
7;5;"3"
|
||||||
|
7;6;" 7.84%/ 46.08%/ 46.08%"
|
||||||
8;1;"balloons"
|
8;1;"balloons"
|
||||||
8;2;"16"
|
8;2;"16"
|
||||||
8;3;"4"
|
8;3;"4"
|
||||||
8;4;"2"
|
8;4;"0"
|
||||||
8;5;"56.25%/ 43.75%"
|
8;5;"2"
|
||||||
|
8;6;"56.25%/ 43.75%"
|
||||||
|
@@ -1,6 +1,6 @@
|
|||||||
[94mDatasets used in benchmark ver. 0.2.0
|
[94mDatasets used in benchmark ver. 0.2.0
|
||||||
|
|
||||||
Dataset Sampl. Feat. Cls Balance
|
Dataset Sampl. Feat. Cont Cls Balance
|
||||||
============================== ====== ===== === ============================================================
|
============================== ====== ===== ==== === ============================================================
|
||||||
[96mbalance-scale 625 4 3 7.84%/ 46.08%/ 46.08%
|
[96mbalance-scale 625 4 0 3 7.84%/ 46.08%/ 46.08%
|
||||||
[94mballoons 16 4 2 56.25%/ 43.75%
|
[94mballoons 16 4 0 2 56.25%/ 43.75%
|
||||||
|
Reference in New Issue
Block a user