diff --git a/benchmark/Datasets.py b/benchmark/Datasets.py index 5e04c73..ed7356e 100644 --- a/benchmark/Datasets.py +++ b/benchmark/Datasets.py @@ -28,6 +28,12 @@ class DatasetsArff: def folder(): return "datasets" + @staticmethod + def get_range_features(X, c_features): + if c_features.strip() == "all": + return list(range(X.shape[1])) + return json.loads(c_features) + def load(self, name, class_name): file_name = os.path.join(self.folder(), self.dataset_names(name)) data = arff.loadarff(file_name) @@ -51,6 +57,10 @@ class DatasetsTanveer: def folder(): return "data" + @staticmethod + def get_range_features(X, name): + return [] + def load(self, name, *args): file_name = os.path.join(self.folder(), self.dataset_names(name)) data = pd.read_csv( @@ -76,6 +86,10 @@ class DatasetsSurcov: def folder(): return "datasets" + @staticmethod + def get_range_features(X, name): + return [] + def load(self, name, *args): file_name = os.path.join(self.folder(), self.dataset_names(name)) data = pd.read_csv( @@ -179,16 +193,13 @@ class Datasets: } def load(self, name, dataframe=False): - def get_range_features(X, name): - c_features = self.continuous_features[name] - if c_features.strip() == "all": - return list(range(X.shape[1])) - return json.loads(c_features) try: class_name = self.class_names[self.data_sets.index(name)] X, y = self.dataset.load(name, class_name) - self.continuous_features_dataset = get_range_features(X, name) + self.continuous_features_dataset = self.dataset.get_range_features( + X, self.continuous_features[name] + ) if self.discretize: X = self.discretize_dataset(X, y) self.build_states(name, X) diff --git a/benchmark/tests/.env b/benchmark/tests/.env index f37499f..82e6f71 100644 --- a/benchmark/tests/.env +++ b/benchmark/tests/.env @@ -7,5 +7,4 @@ stratified=0 source_data=Tanveer seeds=[57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] discretize=0 -ignore_nan=0 - +ignore_nan=0 \ No newline at end of file diff --git a/benchmark/tests/.env.dist b/benchmark/tests/.env.dist index f1b718a..82e6f71 100644 --- a/benchmark/tests/.env.dist +++ b/benchmark/tests/.env.dist @@ -7,4 +7,4 @@ stratified=0 source_data=Tanveer seeds=[57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] discretize=0 -ignore_nan=0 +ignore_nan=0 \ No newline at end of file diff --git a/benchmark/tests/Dataset_test.py b/benchmark/tests/Dataset_test.py index 3c1ca49..ce5521c 100644 --- a/benchmark/tests/Dataset_test.py +++ b/benchmark/tests/Dataset_test.py @@ -1,4 +1,3 @@ -import shutil from .TestBase import TestBase from ..Experiments import Randomized from ..Datasets import Datasets @@ -17,10 +16,6 @@ class DatasetTest(TestBase): self.set_env(".env.dist") return super().tearDown() - @staticmethod - def set_env(env): - shutil.copy(env, ".env") - def test_Randomized(self): expected = [57, 31, 1714, 17, 23, 79, 83, 97, 7, 1] self.assertSequenceEqual(Randomized.seeds(), expected) diff --git a/benchmark/tests/Experiment_test.py b/benchmark/tests/Experiment_test.py index 0f8ffad..722052b 100644 --- a/benchmark/tests/Experiment_test.py +++ b/benchmark/tests/Experiment_test.py @@ -8,10 +8,10 @@ class ExperimentTest(TestBase): def setUp(self): self.exp = self.build_exp() - def build_exp(self, hyperparams=False, grid=False): + def build_exp(self, hyperparams=False, grid=False, model="STree"): params = { "score_name": "accuracy", - "model_name": "STree", + "model_name": model, "stratified": "0", "datasets": Datasets(), "hyperparams_dict": "{}", @@ -21,6 +21,7 @@ class ExperimentTest(TestBase): "title": "Test", "progress_bar": False, "folds": 2, + "ignore_nan": False, } return Experiment(**params) @@ -31,6 +32,7 @@ class ExperimentTest(TestBase): ], ".", ) + self.set_env(".env.dist") return super().tearDown() def test_build_hyperparams_file(self): @@ -89,7 +91,7 @@ class ExperimentTest(TestBase): def test_exception_n_fold_crossval(self): self.exp.do_experiment() with self.assertRaises(ValueError): - self.exp._n_fold_crossval([], [], {}) + self.exp._n_fold_crossval("", [], [], {}) def test_do_experiment(self): self.exp.do_experiment() @@ -131,3 +133,27 @@ class ExperimentTest(TestBase): ): for key, value in expected_result.items(): self.assertEqual(computed_result[key], value) + + def test_build_fit_parameters(self): + self.set_env(".env.arff") + expected = { + "state_names": { + "sepallength": [0, 1, 2], + "sepalwidth": [0, 1, 3, 4], + "petallength": [0, 1, 2, 3], + "petalwidth": [0, 1, 2, 3], + }, + "features": [ + "sepallength", + "sepalwidth", + "petallength", + "petalwidth", + ], + } + exp = self.build_exp(model="TAN") + X, y = exp.datasets.load("iris") + computed = exp._build_fit_params("iris") + for key, value in expected["state_names"].items(): + self.assertEqual(computed["state_names"][key], value) + for feature in expected["features"]: + self.assertIn(feature, computed["features"]) diff --git a/benchmark/tests/TestBase.py b/benchmark/tests/TestBase.py index 96d5e7d..b25bc81 100644 --- a/benchmark/tests/TestBase.py +++ b/benchmark/tests/TestBase.py @@ -4,6 +4,7 @@ import pathlib import sys import csv import unittest +import shutil from importlib import import_module from io import StringIO from unittest.mock import patch @@ -19,6 +20,10 @@ class TestBase(unittest.TestCase): self.stree_version = "1.2.4" super().__init__(*args, **kwargs) + @staticmethod + def set_env(env): + shutil.copy(env, ".env") + def remove_files(self, files, folder): for file_name in files: file_name = os.path.join(folder, file_name) diff --git a/benchmark/tests/Util_test.py b/benchmark/tests/Util_test.py index 8ca7b33..ad291fa 100644 --- a/benchmark/tests/Util_test.py +++ b/benchmark/tests/Util_test.py @@ -180,6 +180,7 @@ class UtilTest(TestBase): "source_data": "Tanveer", "seeds": "[57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]", "discretize": "0", + "ignore_nan": "0", } computed = EnvData().load() self.assertDictEqual(computed, expected) @@ -191,8 +192,16 @@ class UtilTest(TestBase): "n_folds": 5, "model": "STree", "stratified": "0", + "ignore_nan": "0", } ap = argparse.ArgumentParser() + ap.add_argument( + "--ignore-nan", + action=EnvDefault, + envvar="ignore_nan", + required=True, + help="Ignore nan results", + ) ap.add_argument( "-s", "--score", diff --git a/benchmark/tests/datasets/all.txt b/benchmark/tests/datasets/all.txt index ddf732a..48584fd 100644 --- a/benchmark/tests/datasets/all.txt +++ b/benchmark/tests/datasets/all.txt @@ -1,2 +1,2 @@ -iris,class -wine,class +iris,class,all +wine,class,[0, 1] diff --git a/benchmark/tests/test_files/exreport_excel_Datasets.test b/benchmark/tests/test_files/exreport_excel_Datasets.test index 5c2f35a..054b981 100644 --- a/benchmark/tests/test_files/exreport_excel_Datasets.test +++ b/benchmark/tests/test_files/exreport_excel_Datasets.test @@ -1,25 +1,28 @@ -1;1;"Datasets used in benchmark ver. 0.2.0" +1;1;"Datasets used in benchmark ver. 0.4.0" 2;1;" Default score accuracy" 2;2;"Cross validation" -2;5;"5 Folds" +2;6;"5 Folds" 3;2;"Stratified" -3;5;"False" +3;6;"False" 4;2;"Discretized" -4;5;"False" +4;6;"False" 5;2;"Seeds" -5;5;"[57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]" +5;6;"[57, 31, 1714, 17, 23, 79, 83, 97, 7, 1]" 6;1;"Dataset" 6;2;"Samples" 6;3;"Features" -6;4;"Classes" -6;5;"Balance" +6;4;"Continuous" +6;5;"Classes" +6;6;"Balance" 7;1;"balance-scale" 7;2;"625" 7;3;"4" -7;4;"3" -7;5;" 7.84%/ 46.08%/ 46.08%" +7;4;"0" +7;5;"3" +7;6;" 7.84%/ 46.08%/ 46.08%" 8;1;"balloons" 8;2;"16" 8;3;"4" -8;4;"2" -8;5;"56.25%/ 43.75%" +8;4;"0" +8;5;"2" +8;6;"56.25%/ 43.75%" diff --git a/benchmark/tests/test_files/report_datasets.test b/benchmark/tests/test_files/report_datasets.test index 16c7bd7..3fa0aeb 100644 --- a/benchmark/tests/test_files/report_datasets.test +++ b/benchmark/tests/test_files/report_datasets.test @@ -1,6 +1,6 @@ Datasets used in benchmark ver. 0.2.0 -Dataset Sampl. Feat. Cls Balance -============================== ====== ===== === ============================================================ -balance-scale 625 4 3 7.84%/ 46.08%/ 46.08% -balloons 16 4 2 56.25%/ 43.75% +Dataset Sampl. Feat. Cont Cls Balance +============================== ====== ===== ==== === ============================================================ +balance-scale 625 4 0 3 7.84%/ 46.08%/ 46.08% +balloons 16 4 0 2 56.25%/ 43.75%