From 7bbfb4b68e9376859d7a35b1f33fd5921f51bf43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Sat, 23 Apr 2022 10:54:51 +0200 Subject: [PATCH] 100% coverage in Models --- benchmark/Models.py | 52 +++++++++++---------- benchmark/tests/Models_test.py | 84 ++++++++++++++++++++++++++++++++++ benchmark/tests/Util_test.py | 5 -- benchmark/tests/__init__.py | 3 +- 4 files changed, 115 insertions(+), 29 deletions(-) create mode 100644 benchmark/tests/Models_test.py diff --git a/benchmark/Models.py b/benchmark/Models.py index ef5ab72..627142d 100644 --- a/benchmark/Models.py +++ b/benchmark/Models.py @@ -15,28 +15,41 @@ class Models: @staticmethod def get_model(name, random_state=None): if name == "STree": - return Stree() + return Stree(random_state=random_state) if name == "Cart": - return DecisionTreeClassifier() + return DecisionTreeClassifier(random_state=random_state) if name == "ExtraTree": - return ExtraTreeClassifier() + return ExtraTreeClassifier(random_state=random_state) if name == "Wodt": - return Wodt() + return Wodt(random_state=random_state) if name == "SVC": - return SVC() + return SVC(random_state=random_state) if name == "ODTE": - return Odte(base_estimator=Stree()) + return Odte( + base_estimator=Stree(random_state=random_state), + random_state=random_state, + ) if name == "BaggingStree": clf = Stree(random_state=random_state) - return BaggingClassifier(base_estimator=clf) + return BaggingClassifier( + base_estimator=clf, random_state=random_state + ) if name == "BaggingWodt": clf = Wodt(random_state=random_state) - return BaggingClassifier(base_estimator=clf) + return BaggingClassifier( + base_estimator=clf, random_state=random_state + ) if name == "AdaBoostStree": - clf = Stree(random_state=random_state) - return AdaBoostClassifier(base_estimator=clf) + clf = Stree( + random_state=random_state, + ) + return AdaBoostClassifier( + base_estimator=clf, + algorithm="SAMME", + random_state=random_state, + ) if name == "RandomForest": - return RandomForestClassifier() + return RandomForestClassifier(random_state=random_state) msg = f"No model recognized {name}" if name in ("Stree", "stree"): msg += ", did you mean STree?" @@ -55,18 +68,11 @@ class Models: leaves = result.get_n_leaves() depth = 0 elif name.startswith("Bagging") or name.startswith("AdaBoost"): - if hasattr(result.base_estimator_, "nodes_leaves"): - nodes, leaves = list( - zip(*[x.nodes_leaves() for x in result.estimators_]) - ) - nodes, leaves = mean(nodes), mean(leaves) - depth = mean([x.depth_ for x in result.estimators_]) - elif hasattr(result.base_estimator_, "tree_"): - nodes = mean([x.tree_.node_count for x in result.estimators_]) - leaves = mean([x.get_n_leaves() for x in result.estimators_]) - depth = mean([x.get_depth() for x in result.estimators_]) - else: - nodes = leaves = depth = 0 + nodes, leaves = list( + zip(*[x.nodes_leaves() for x in result.estimators_]) + ) + nodes, leaves = mean(nodes), mean(leaves) + depth = mean([x.depth_ for x in result.estimators_]) elif name == "RandomForest": leaves = mean([x.get_n_leaves() for x in result.estimators_]) depth = mean([x.get_depth() for x in result.estimators_]) diff --git a/benchmark/tests/Models_test.py b/benchmark/tests/Models_test.py new file mode 100644 index 0000000..de54763 --- /dev/null +++ b/benchmark/tests/Models_test.py @@ -0,0 +1,84 @@ +import unittest +import warnings +from sklearn.exceptions import ConvergenceWarning +from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier +from sklearn.ensemble import ( + RandomForestClassifier, + BaggingClassifier, + AdaBoostClassifier, +) +from sklearn.svm import SVC +from sklearn.datasets import load_wine +from stree import Stree +from wodt import Wodt +from odte import Odte +from ..Models import Models + + +class ModelTest(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def test_Models(self): + test = { + "STree": Stree, + "Wodt": Wodt, + "ODTE": Odte, + "Cart": DecisionTreeClassifier, + "SVC": SVC, + "RandomForest": RandomForestClassifier, + "ExtraTree": ExtraTreeClassifier, + } + for key, value in test.items(): + self.assertIsInstance(Models.get_model(key), value) + + def test_BaggingStree(self): + clf = Models.get_model("BaggingStree") + self.assertIsInstance(clf, BaggingClassifier) + clf_base = clf.base_estimator + self.assertIsInstance(clf_base, Stree) + + def test_BaggingWodt(self): + clf = Models.get_model("BaggingWodt") + self.assertIsInstance(clf, BaggingClassifier) + clf_base = clf.base_estimator + self.assertIsInstance(clf_base, Wodt) + + def test_AdaBoostStree(self): + clf = Models.get_model("AdaBoostStree") + self.assertIsInstance(clf, AdaBoostClassifier) + clf_base = clf.base_estimator + self.assertIsInstance(clf_base, Stree) + + def test_unknown_classifier(self): + with self.assertRaises(ValueError): + Models.get_model("unknown") + + def test_bogus_Stree(self): + with self.assertRaises(ValueError): + Models.get_model("Stree") + + def test_bogus_Odte(self): + with self.assertRaises(ValueError): + Models.get_model("Odte") + + def test_get_complexity(self): + warnings.filterwarnings("ignore", category=ConvergenceWarning) + test = { + "STree": (11, 6, 4), + "Wodt": (303, 152, 50), + "ODTE": (7.84, 4.42, 3.36), + "Cart": (23, 12, 5), + "SVC": (0, 0, 0), + "RandomForest": (21.3, 11, 5.26), + "ExtraTree": (0, 38, 0), + "AdaBoostStree": (12.25, 6.625, 4.75), + "BaggingStree": (8.4, 4.7, 3.5), + "BaggingWodt": (272, 136.5, 50), + } + X, y = load_wine(return_X_y=True) + for key, value in test.items(): + clf = Models.get_model(key, random_state=1) + clf.fit(X, y) + # print(key, Models.get_complexity(key, clf)) + self.assertSequenceEqual(Models.get_complexity(key, clf), value) diff --git a/benchmark/tests/Util_test.py b/benchmark/tests/Util_test.py index 04c00d4..4910cc8 100644 --- a/benchmark/tests/Util_test.py +++ b/benchmark/tests/Util_test.py @@ -6,11 +6,6 @@ from ..Utils import Folders, Files, Symbols, TextColor, EnvData, EnvDefault class UtilTest(unittest.TestCase): - def __init__(self, *args, **kwargs): - self._random_state = 1 - self._kernels = ["liblinear", "linear", "rbf", "poly", "sigmoid"] - super().__init__(*args, **kwargs) - def test_Folders(self): self.assertEqual("results", Folders.results) self.assertEqual("hidden_results", Folders.hidden_results) diff --git a/benchmark/tests/__init__.py b/benchmark/tests/__init__.py index 8523cc1..897ae22 100644 --- a/benchmark/tests/__init__.py +++ b/benchmark/tests/__init__.py @@ -1,3 +1,4 @@ from .Util_test import UtilTest +from .Models_test import ModelTest -all = ["UtilTest"] +all = ["UtilTest", "ModelTest"]