mirror of
https://github.com/Doctorado-ML/benchmark.git
synced 2025-08-15 23:45:54 +00:00
100% coverage in Models
This commit is contained in:
@@ -15,28 +15,41 @@ class Models:
|
||||
@staticmethod
|
||||
def get_model(name, random_state=None):
|
||||
if name == "STree":
|
||||
return Stree()
|
||||
return Stree(random_state=random_state)
|
||||
if name == "Cart":
|
||||
return DecisionTreeClassifier()
|
||||
return DecisionTreeClassifier(random_state=random_state)
|
||||
if name == "ExtraTree":
|
||||
return ExtraTreeClassifier()
|
||||
return ExtraTreeClassifier(random_state=random_state)
|
||||
if name == "Wodt":
|
||||
return Wodt()
|
||||
return Wodt(random_state=random_state)
|
||||
if name == "SVC":
|
||||
return SVC()
|
||||
return SVC(random_state=random_state)
|
||||
if name == "ODTE":
|
||||
return Odte(base_estimator=Stree())
|
||||
return Odte(
|
||||
base_estimator=Stree(random_state=random_state),
|
||||
random_state=random_state,
|
||||
)
|
||||
if name == "BaggingStree":
|
||||
clf = Stree(random_state=random_state)
|
||||
return BaggingClassifier(base_estimator=clf)
|
||||
return BaggingClassifier(
|
||||
base_estimator=clf, random_state=random_state
|
||||
)
|
||||
if name == "BaggingWodt":
|
||||
clf = Wodt(random_state=random_state)
|
||||
return BaggingClassifier(base_estimator=clf)
|
||||
return BaggingClassifier(
|
||||
base_estimator=clf, random_state=random_state
|
||||
)
|
||||
if name == "AdaBoostStree":
|
||||
clf = Stree(random_state=random_state)
|
||||
return AdaBoostClassifier(base_estimator=clf)
|
||||
clf = Stree(
|
||||
random_state=random_state,
|
||||
)
|
||||
return AdaBoostClassifier(
|
||||
base_estimator=clf,
|
||||
algorithm="SAMME",
|
||||
random_state=random_state,
|
||||
)
|
||||
if name == "RandomForest":
|
||||
return RandomForestClassifier()
|
||||
return RandomForestClassifier(random_state=random_state)
|
||||
msg = f"No model recognized {name}"
|
||||
if name in ("Stree", "stree"):
|
||||
msg += ", did you mean STree?"
|
||||
@@ -55,18 +68,11 @@ class Models:
|
||||
leaves = result.get_n_leaves()
|
||||
depth = 0
|
||||
elif name.startswith("Bagging") or name.startswith("AdaBoost"):
|
||||
if hasattr(result.base_estimator_, "nodes_leaves"):
|
||||
nodes, leaves = list(
|
||||
zip(*[x.nodes_leaves() for x in result.estimators_])
|
||||
)
|
||||
nodes, leaves = mean(nodes), mean(leaves)
|
||||
depth = mean([x.depth_ for x in result.estimators_])
|
||||
elif hasattr(result.base_estimator_, "tree_"):
|
||||
nodes = mean([x.tree_.node_count for x in result.estimators_])
|
||||
leaves = mean([x.get_n_leaves() for x in result.estimators_])
|
||||
depth = mean([x.get_depth() for x in result.estimators_])
|
||||
else:
|
||||
nodes = leaves = depth = 0
|
||||
nodes, leaves = list(
|
||||
zip(*[x.nodes_leaves() for x in result.estimators_])
|
||||
)
|
||||
nodes, leaves = mean(nodes), mean(leaves)
|
||||
depth = mean([x.depth_ for x in result.estimators_])
|
||||
elif name == "RandomForest":
|
||||
leaves = mean([x.get_n_leaves() for x in result.estimators_])
|
||||
depth = mean([x.get_depth() for x in result.estimators_])
|
||||
|
84
benchmark/tests/Models_test.py
Normal file
84
benchmark/tests/Models_test.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import unittest
|
||||
import warnings
|
||||
from sklearn.exceptions import ConvergenceWarning
|
||||
from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier
|
||||
from sklearn.ensemble import (
|
||||
RandomForestClassifier,
|
||||
BaggingClassifier,
|
||||
AdaBoostClassifier,
|
||||
)
|
||||
from sklearn.svm import SVC
|
||||
from sklearn.datasets import load_wine
|
||||
from stree import Stree
|
||||
from wodt import Wodt
|
||||
from odte import Odte
|
||||
from ..Models import Models
|
||||
|
||||
|
||||
class ModelTest(unittest.TestCase):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def test_Models(self):
|
||||
test = {
|
||||
"STree": Stree,
|
||||
"Wodt": Wodt,
|
||||
"ODTE": Odte,
|
||||
"Cart": DecisionTreeClassifier,
|
||||
"SVC": SVC,
|
||||
"RandomForest": RandomForestClassifier,
|
||||
"ExtraTree": ExtraTreeClassifier,
|
||||
}
|
||||
for key, value in test.items():
|
||||
self.assertIsInstance(Models.get_model(key), value)
|
||||
|
||||
def test_BaggingStree(self):
|
||||
clf = Models.get_model("BaggingStree")
|
||||
self.assertIsInstance(clf, BaggingClassifier)
|
||||
clf_base = clf.base_estimator
|
||||
self.assertIsInstance(clf_base, Stree)
|
||||
|
||||
def test_BaggingWodt(self):
|
||||
clf = Models.get_model("BaggingWodt")
|
||||
self.assertIsInstance(clf, BaggingClassifier)
|
||||
clf_base = clf.base_estimator
|
||||
self.assertIsInstance(clf_base, Wodt)
|
||||
|
||||
def test_AdaBoostStree(self):
|
||||
clf = Models.get_model("AdaBoostStree")
|
||||
self.assertIsInstance(clf, AdaBoostClassifier)
|
||||
clf_base = clf.base_estimator
|
||||
self.assertIsInstance(clf_base, Stree)
|
||||
|
||||
def test_unknown_classifier(self):
|
||||
with self.assertRaises(ValueError):
|
||||
Models.get_model("unknown")
|
||||
|
||||
def test_bogus_Stree(self):
|
||||
with self.assertRaises(ValueError):
|
||||
Models.get_model("Stree")
|
||||
|
||||
def test_bogus_Odte(self):
|
||||
with self.assertRaises(ValueError):
|
||||
Models.get_model("Odte")
|
||||
|
||||
def test_get_complexity(self):
|
||||
warnings.filterwarnings("ignore", category=ConvergenceWarning)
|
||||
test = {
|
||||
"STree": (11, 6, 4),
|
||||
"Wodt": (303, 152, 50),
|
||||
"ODTE": (7.84, 4.42, 3.36),
|
||||
"Cart": (23, 12, 5),
|
||||
"SVC": (0, 0, 0),
|
||||
"RandomForest": (21.3, 11, 5.26),
|
||||
"ExtraTree": (0, 38, 0),
|
||||
"AdaBoostStree": (12.25, 6.625, 4.75),
|
||||
"BaggingStree": (8.4, 4.7, 3.5),
|
||||
"BaggingWodt": (272, 136.5, 50),
|
||||
}
|
||||
X, y = load_wine(return_X_y=True)
|
||||
for key, value in test.items():
|
||||
clf = Models.get_model(key, random_state=1)
|
||||
clf.fit(X, y)
|
||||
# print(key, Models.get_complexity(key, clf))
|
||||
self.assertSequenceEqual(Models.get_complexity(key, clf), value)
|
@@ -6,11 +6,6 @@ from ..Utils import Folders, Files, Symbols, TextColor, EnvData, EnvDefault
|
||||
|
||||
|
||||
class UtilTest(unittest.TestCase):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._random_state = 1
|
||||
self._kernels = ["liblinear", "linear", "rbf", "poly", "sigmoid"]
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def test_Folders(self):
|
||||
self.assertEqual("results", Folders.results)
|
||||
self.assertEqual("hidden_results", Folders.hidden_results)
|
||||
|
@@ -1,3 +1,4 @@
|
||||
from .Util_test import UtilTest
|
||||
from .Models_test import ModelTest
|
||||
|
||||
all = ["UtilTest"]
|
||||
all = ["UtilTest", "ModelTest"]
|
||||
|
Reference in New Issue
Block a user