100% coverage in Models

This commit is contained in:
2022-04-23 10:54:51 +02:00
parent 5b6cb17edc
commit 7bbfb4b68e
4 changed files with 115 additions and 29 deletions

View File

@@ -15,28 +15,41 @@ class Models:
@staticmethod
def get_model(name, random_state=None):
if name == "STree":
return Stree()
return Stree(random_state=random_state)
if name == "Cart":
return DecisionTreeClassifier()
return DecisionTreeClassifier(random_state=random_state)
if name == "ExtraTree":
return ExtraTreeClassifier()
return ExtraTreeClassifier(random_state=random_state)
if name == "Wodt":
return Wodt()
return Wodt(random_state=random_state)
if name == "SVC":
return SVC()
return SVC(random_state=random_state)
if name == "ODTE":
return Odte(base_estimator=Stree())
return Odte(
base_estimator=Stree(random_state=random_state),
random_state=random_state,
)
if name == "BaggingStree":
clf = Stree(random_state=random_state)
return BaggingClassifier(base_estimator=clf)
return BaggingClassifier(
base_estimator=clf, random_state=random_state
)
if name == "BaggingWodt":
clf = Wodt(random_state=random_state)
return BaggingClassifier(base_estimator=clf)
return BaggingClassifier(
base_estimator=clf, random_state=random_state
)
if name == "AdaBoostStree":
clf = Stree(random_state=random_state)
return AdaBoostClassifier(base_estimator=clf)
clf = Stree(
random_state=random_state,
)
return AdaBoostClassifier(
base_estimator=clf,
algorithm="SAMME",
random_state=random_state,
)
if name == "RandomForest":
return RandomForestClassifier()
return RandomForestClassifier(random_state=random_state)
msg = f"No model recognized {name}"
if name in ("Stree", "stree"):
msg += ", did you mean STree?"
@@ -55,18 +68,11 @@ class Models:
leaves = result.get_n_leaves()
depth = 0
elif name.startswith("Bagging") or name.startswith("AdaBoost"):
if hasattr(result.base_estimator_, "nodes_leaves"):
nodes, leaves = list(
zip(*[x.nodes_leaves() for x in result.estimators_])
)
nodes, leaves = mean(nodes), mean(leaves)
depth = mean([x.depth_ for x in result.estimators_])
elif hasattr(result.base_estimator_, "tree_"):
nodes = mean([x.tree_.node_count for x in result.estimators_])
leaves = mean([x.get_n_leaves() for x in result.estimators_])
depth = mean([x.get_depth() for x in result.estimators_])
else:
nodes = leaves = depth = 0
elif name == "RandomForest":
leaves = mean([x.get_n_leaves() for x in result.estimators_])
depth = mean([x.get_depth() for x in result.estimators_])

View File

@@ -0,0 +1,84 @@
import unittest
import warnings
from sklearn.exceptions import ConvergenceWarning
from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier
from sklearn.ensemble import (
RandomForestClassifier,
BaggingClassifier,
AdaBoostClassifier,
)
from sklearn.svm import SVC
from sklearn.datasets import load_wine
from stree import Stree
from wodt import Wodt
from odte import Odte
from ..Models import Models
class ModelTest(unittest.TestCase):
    """Tests for the ``Models`` factory and its complexity report."""

    def _check_ensemble(self, name, ensemble_class, base_class):
        """Check the ensemble built for ``name`` and its base estimator."""
        clf = Models.get_model(name)
        self.assertIsInstance(clf, ensemble_class)
        self.assertIsInstance(clf.base_estimator, base_class)

    def test_Models(self):
        """Each plain model name maps to the expected estimator class."""
        expected = {
            "STree": Stree,
            "Wodt": Wodt,
            "ODTE": Odte,
            "Cart": DecisionTreeClassifier,
            "SVC": SVC,
            "RandomForest": RandomForestClassifier,
            "ExtraTree": ExtraTreeClassifier,
        }
        for key, value in expected.items():
            # subTest reports every failing model, not only the first one.
            with self.subTest(model=key):
                self.assertIsInstance(Models.get_model(key), value)

    def test_BaggingStree(self):
        """BaggingStree is a bagging ensemble of Stree classifiers."""
        self._check_ensemble("BaggingStree", BaggingClassifier, Stree)

    def test_BaggingWodt(self):
        """BaggingWodt is a bagging ensemble of Wodt classifiers."""
        self._check_ensemble("BaggingWodt", BaggingClassifier, Wodt)

    def test_AdaBoostStree(self):
        """AdaBoostStree is an AdaBoost ensemble of Stree classifiers."""
        self._check_ensemble("AdaBoostStree", AdaBoostClassifier, Stree)

    def test_unknown_classifier(self):
        """An unrecognized model name raises ValueError."""
        with self.assertRaises(ValueError):
            Models.get_model("unknown")

    def test_bogus_Stree(self):
        """The misspelled name "Stree" is rejected with ValueError."""
        with self.assertRaises(ValueError):
            Models.get_model("Stree")

    def test_bogus_Odte(self):
        """The misspelled name "Odte" is rejected with ValueError."""
        with self.assertRaises(ValueError):
            Models.get_model("Odte")

    def test_get_complexity(self):
        """Complexity of each model fitted on wine with random_state=1."""
        # Expected triples as returned by Models.get_complexity; presumably
        # ordered (nodes, leaves, depth) -- confirm against Models.
        expected = {
            "STree": (11, 6, 4),
            "Wodt": (303, 152, 50),
            "ODTE": (7.84, 4.42, 3.36),
            "Cart": (23, 12, 5),
            "SVC": (0, 0, 0),
            "RandomForest": (21.3, 11, 5.26),
            "ExtraTree": (0, 38, 0),
            "AdaBoostStree": (12.25, 6.625, 4.75),
            "BaggingStree": (8.4, 4.7, 3.5),
            "BaggingWodt": (272, 136.5, 50),
        }
        X, y = load_wine(return_X_y=True)
        # catch_warnings restores the global filter list on exit, so the
        # "ignore" filter does not leak into tests that run afterwards.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=ConvergenceWarning)
            for key, value in expected.items():
                with self.subTest(model=key):
                    clf = Models.get_model(key, random_state=1)
                    clf.fit(X, y)
                    self.assertSequenceEqual(
                        Models.get_complexity(key, clf), value
                    )

View File

@@ -6,11 +6,6 @@ from ..Utils import Folders, Files, Symbols, TextColor, EnvData, EnvDefault
class UtilTest(unittest.TestCase):
def __init__(self, *args, **kwargs):
self._random_state = 1
self._kernels = ["liblinear", "linear", "rbf", "poly", "sigmoid"]
super().__init__(*args, **kwargs)
def test_Folders(self):
self.assertEqual("results", Folders.results)
self.assertEqual("hidden_results", Folders.hidden_results)

View File

@@ -1,3 +1,4 @@
from .Util_test import UtilTest
from .Models_test import ModelTest
# Public test cases exported by this package. The name must be ``__all__``:
# a plain ``all`` shadows the builtin and is ignored by ``import *``.
__all__ = ["UtilTest", "ModelTest"]