mirror of
https://github.com/Doctorado-ML/benchmark.git
synced 2025-08-17 00:15:55 +00:00
Fix tests
This commit is contained in:
@@ -28,6 +28,12 @@ class DatasetsArff:
|
||||
def folder():
|
||||
return "datasets"
|
||||
|
||||
@staticmethod
|
||||
def get_range_features(X, c_features):
|
||||
if c_features.strip() == "all":
|
||||
return list(range(X.shape[1]))
|
||||
return json.loads(c_features)
|
||||
|
||||
def load(self, name, class_name):
|
||||
file_name = os.path.join(self.folder(), self.dataset_names(name))
|
||||
data = arff.loadarff(file_name)
|
||||
@@ -51,6 +57,10 @@ class DatasetsTanveer:
|
||||
def folder():
|
||||
return "data"
|
||||
|
||||
@staticmethod
|
||||
def get_range_features(X, name):
|
||||
return []
|
||||
|
||||
def load(self, name, *args):
|
||||
file_name = os.path.join(self.folder(), self.dataset_names(name))
|
||||
data = pd.read_csv(
|
||||
@@ -76,6 +86,10 @@ class DatasetsSurcov:
|
||||
def folder():
|
||||
return "datasets"
|
||||
|
||||
@staticmethod
|
||||
def get_range_features(X, name):
|
||||
return []
|
||||
|
||||
def load(self, name, *args):
|
||||
file_name = os.path.join(self.folder(), self.dataset_names(name))
|
||||
data = pd.read_csv(
|
||||
@@ -179,16 +193,13 @@ class Datasets:
|
||||
}
|
||||
|
||||
def load(self, name, dataframe=False):
|
||||
def get_range_features(X, name):
|
||||
c_features = self.continuous_features[name]
|
||||
if c_features.strip() == "all":
|
||||
return list(range(X.shape[1]))
|
||||
return json.loads(c_features)
|
||||
|
||||
try:
|
||||
class_name = self.class_names[self.data_sets.index(name)]
|
||||
X, y = self.dataset.load(name, class_name)
|
||||
self.continuous_features_dataset = get_range_features(X, name)
|
||||
self.continuous_features_dataset = self.dataset.get_range_features(
|
||||
X, self.continuous_features[name]
|
||||
)
|
||||
if self.discretize:
|
||||
X = self.discretize_dataset(X, y)
|
||||
self.build_states(name, X)
|
||||
|
@@ -7,5 +7,4 @@ stratified=0
|
||||
source_data=Tanveer
|
||||
seeds=[57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]
|
||||
discretize=0
|
||||
ignore_nan=0
|
||||
|
||||
ignore_nan=0
|
@@ -7,4 +7,4 @@ stratified=0
|
||||
source_data=Tanveer
|
||||
seeds=[57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]
|
||||
discretize=0
|
||||
ignore_nan=0
|
||||
ignore_nan=0
|
@@ -1,4 +1,3 @@
|
||||
import shutil
|
||||
from .TestBase import TestBase
|
||||
from ..Experiments import Randomized
|
||||
from ..Datasets import Datasets
|
||||
@@ -17,10 +16,6 @@ class DatasetTest(TestBase):
|
||||
self.set_env(".env.dist")
|
||||
return super().tearDown()
|
||||
|
||||
@staticmethod
|
||||
def set_env(env):
|
||||
shutil.copy(env, ".env")
|
||||
|
||||
def test_Randomized(self):
|
||||
expected = [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]
|
||||
self.assertSequenceEqual(Randomized.seeds(), expected)
|
||||
|
@@ -8,10 +8,10 @@ class ExperimentTest(TestBase):
|
||||
def setUp(self):
|
||||
self.exp = self.build_exp()
|
||||
|
||||
def build_exp(self, hyperparams=False, grid=False):
|
||||
def build_exp(self, hyperparams=False, grid=False, model="STree"):
|
||||
params = {
|
||||
"score_name": "accuracy",
|
||||
"model_name": "STree",
|
||||
"model_name": model,
|
||||
"stratified": "0",
|
||||
"datasets": Datasets(),
|
||||
"hyperparams_dict": "{}",
|
||||
@@ -21,6 +21,7 @@ class ExperimentTest(TestBase):
|
||||
"title": "Test",
|
||||
"progress_bar": False,
|
||||
"folds": 2,
|
||||
"ignore_nan": False,
|
||||
}
|
||||
return Experiment(**params)
|
||||
|
||||
@@ -31,6 +32,7 @@ class ExperimentTest(TestBase):
|
||||
],
|
||||
".",
|
||||
)
|
||||
self.set_env(".env.dist")
|
||||
return super().tearDown()
|
||||
|
||||
def test_build_hyperparams_file(self):
|
||||
@@ -89,7 +91,7 @@ class ExperimentTest(TestBase):
|
||||
def test_exception_n_fold_crossval(self):
|
||||
self.exp.do_experiment()
|
||||
with self.assertRaises(ValueError):
|
||||
self.exp._n_fold_crossval([], [], {})
|
||||
self.exp._n_fold_crossval("", [], [], {})
|
||||
|
||||
def test_do_experiment(self):
|
||||
self.exp.do_experiment()
|
||||
@@ -131,3 +133,27 @@ class ExperimentTest(TestBase):
|
||||
):
|
||||
for key, value in expected_result.items():
|
||||
self.assertEqual(computed_result[key], value)
|
||||
|
||||
def test_build_fit_parameters(self):
|
||||
self.set_env(".env.arff")
|
||||
expected = {
|
||||
"state_names": {
|
||||
"sepallength": [0, 1, 2],
|
||||
"sepalwidth": [0, 1, 3, 4],
|
||||
"petallength": [0, 1, 2, 3],
|
||||
"petalwidth": [0, 1, 2, 3],
|
||||
},
|
||||
"features": [
|
||||
"sepallength",
|
||||
"sepalwidth",
|
||||
"petallength",
|
||||
"petalwidth",
|
||||
],
|
||||
}
|
||||
exp = self.build_exp(model="TAN")
|
||||
X, y = exp.datasets.load("iris")
|
||||
computed = exp._build_fit_params("iris")
|
||||
for key, value in expected["state_names"].items():
|
||||
self.assertEqual(computed["state_names"][key], value)
|
||||
for feature in expected["features"]:
|
||||
self.assertIn(feature, computed["features"])
|
||||
|
@@ -4,6 +4,7 @@ import pathlib
|
||||
import sys
|
||||
import csv
|
||||
import unittest
|
||||
import shutil
|
||||
from importlib import import_module
|
||||
from io import StringIO
|
||||
from unittest.mock import patch
|
||||
@@ -19,6 +20,10 @@ class TestBase(unittest.TestCase):
|
||||
self.stree_version = "1.2.4"
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def set_env(env):
|
||||
shutil.copy(env, ".env")
|
||||
|
||||
def remove_files(self, files, folder):
|
||||
for file_name in files:
|
||||
file_name = os.path.join(folder, file_name)
|
||||
|
@@ -180,6 +180,7 @@ class UtilTest(TestBase):
|
||||
"source_data": "Tanveer",
|
||||
"seeds": "[57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]",
|
||||
"discretize": "0",
|
||||
"ignore_nan": "0",
|
||||
}
|
||||
computed = EnvData().load()
|
||||
self.assertDictEqual(computed, expected)
|
||||
@@ -191,8 +192,16 @@ class UtilTest(TestBase):
|
||||
"n_folds": 5,
|
||||
"model": "STree",
|
||||
"stratified": "0",
|
||||
"ignore_nan": "0",
|
||||
}
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument(
|
||||
"--ignore-nan",
|
||||
action=EnvDefault,
|
||||
envvar="ignore_nan",
|
||||
required=True,
|
||||
help="Ignore nan results",
|
||||
)
|
||||
ap.add_argument(
|
||||
"-s",
|
||||
"--score",
|
||||
|
@@ -1,2 +1,2 @@
|
||||
iris,class
|
||||
wine,class
|
||||
iris,class,all
|
||||
wine,class,[0, 1]
|
||||
|
@@ -1,25 +1,28 @@
|
||||
1;1;"Datasets used in benchmark ver. 0.2.0"
|
||||
1;1;"Datasets used in benchmark ver. 0.4.0"
|
||||
2;1;" Default score accuracy"
|
||||
2;2;"Cross validation"
|
||||
2;5;"5 Folds"
|
||||
2;6;"5 Folds"
|
||||
3;2;"Stratified"
|
||||
3;5;"False"
|
||||
3;6;"False"
|
||||
4;2;"Discretized"
|
||||
4;5;"False"
|
||||
4;6;"False"
|
||||
5;2;"Seeds"
|
||||
5;5;"[57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]"
|
||||
5;6;"[57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]"
|
||||
6;1;"Dataset"
|
||||
6;2;"Samples"
|
||||
6;3;"Features"
|
||||
6;4;"Classes"
|
||||
6;5;"Balance"
|
||||
6;4;"Continuous"
|
||||
6;5;"Classes"
|
||||
6;6;"Balance"
|
||||
7;1;"balance-scale"
|
||||
7;2;"625"
|
||||
7;3;"4"
|
||||
7;4;"3"
|
||||
7;5;" 7.84%/ 46.08%/ 46.08%"
|
||||
7;4;"0"
|
||||
7;5;"3"
|
||||
7;6;" 7.84%/ 46.08%/ 46.08%"
|
||||
8;1;"balloons"
|
||||
8;2;"16"
|
||||
8;3;"4"
|
||||
8;4;"2"
|
||||
8;5;"56.25%/ 43.75%"
|
||||
8;4;"0"
|
||||
8;5;"2"
|
||||
8;6;"56.25%/ 43.75%"
|
||||
|
@@ -1,6 +1,6 @@
|
||||
[94mDatasets used in benchmark ver. 0.2.0
|
||||
|
||||
Dataset Sampl. Feat. Cls Balance
|
||||
============================== ====== ===== === ============================================================
|
||||
[96mbalance-scale 625 4 3 7.84%/ 46.08%/ 46.08%
|
||||
[94mballoons 16 4 2 56.25%/ 43.75%
|
||||
Dataset Sampl. Feat. Cont Cls Balance
|
||||
============================== ====== ===== ==== === ============================================================
|
||||
[96mbalance-scale 625 4 0 3 7.84%/ 46.08%/ 46.08%
|
||||
[94mballoons 16 4 0 2 56.25%/ 43.75%
|
||||
|
Reference in New Issue
Block a user