mirror of
https://github.com/Doctorado-ML/benchmark.git
synced 2025-08-16 07:55:54 +00:00
100% coverage in Models
This commit is contained in:
@@ -15,28 +15,41 @@ class Models:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def get_model(name, random_state=None):
|
def get_model(name, random_state=None):
|
||||||
if name == "STree":
|
if name == "STree":
|
||||||
return Stree()
|
return Stree(random_state=random_state)
|
||||||
if name == "Cart":
|
if name == "Cart":
|
||||||
return DecisionTreeClassifier()
|
return DecisionTreeClassifier(random_state=random_state)
|
||||||
if name == "ExtraTree":
|
if name == "ExtraTree":
|
||||||
return ExtraTreeClassifier()
|
return ExtraTreeClassifier(random_state=random_state)
|
||||||
if name == "Wodt":
|
if name == "Wodt":
|
||||||
return Wodt()
|
return Wodt(random_state=random_state)
|
||||||
if name == "SVC":
|
if name == "SVC":
|
||||||
return SVC()
|
return SVC(random_state=random_state)
|
||||||
if name == "ODTE":
|
if name == "ODTE":
|
||||||
return Odte(base_estimator=Stree())
|
return Odte(
|
||||||
|
base_estimator=Stree(random_state=random_state),
|
||||||
|
random_state=random_state,
|
||||||
|
)
|
||||||
if name == "BaggingStree":
|
if name == "BaggingStree":
|
||||||
clf = Stree(random_state=random_state)
|
clf = Stree(random_state=random_state)
|
||||||
return BaggingClassifier(base_estimator=clf)
|
return BaggingClassifier(
|
||||||
|
base_estimator=clf, random_state=random_state
|
||||||
|
)
|
||||||
if name == "BaggingWodt":
|
if name == "BaggingWodt":
|
||||||
clf = Wodt(random_state=random_state)
|
clf = Wodt(random_state=random_state)
|
||||||
return BaggingClassifier(base_estimator=clf)
|
return BaggingClassifier(
|
||||||
|
base_estimator=clf, random_state=random_state
|
||||||
|
)
|
||||||
if name == "AdaBoostStree":
|
if name == "AdaBoostStree":
|
||||||
clf = Stree(random_state=random_state)
|
clf = Stree(
|
||||||
return AdaBoostClassifier(base_estimator=clf)
|
random_state=random_state,
|
||||||
|
)
|
||||||
|
return AdaBoostClassifier(
|
||||||
|
base_estimator=clf,
|
||||||
|
algorithm="SAMME",
|
||||||
|
random_state=random_state,
|
||||||
|
)
|
||||||
if name == "RandomForest":
|
if name == "RandomForest":
|
||||||
return RandomForestClassifier()
|
return RandomForestClassifier(random_state=random_state)
|
||||||
msg = f"No model recognized {name}"
|
msg = f"No model recognized {name}"
|
||||||
if name in ("Stree", "stree"):
|
if name in ("Stree", "stree"):
|
||||||
msg += ", did you mean STree?"
|
msg += ", did you mean STree?"
|
||||||
@@ -55,18 +68,11 @@ class Models:
|
|||||||
leaves = result.get_n_leaves()
|
leaves = result.get_n_leaves()
|
||||||
depth = 0
|
depth = 0
|
||||||
elif name.startswith("Bagging") or name.startswith("AdaBoost"):
|
elif name.startswith("Bagging") or name.startswith("AdaBoost"):
|
||||||
if hasattr(result.base_estimator_, "nodes_leaves"):
|
nodes, leaves = list(
|
||||||
nodes, leaves = list(
|
zip(*[x.nodes_leaves() for x in result.estimators_])
|
||||||
zip(*[x.nodes_leaves() for x in result.estimators_])
|
)
|
||||||
)
|
nodes, leaves = mean(nodes), mean(leaves)
|
||||||
nodes, leaves = mean(nodes), mean(leaves)
|
depth = mean([x.depth_ for x in result.estimators_])
|
||||||
depth = mean([x.depth_ for x in result.estimators_])
|
|
||||||
elif hasattr(result.base_estimator_, "tree_"):
|
|
||||||
nodes = mean([x.tree_.node_count for x in result.estimators_])
|
|
||||||
leaves = mean([x.get_n_leaves() for x in result.estimators_])
|
|
||||||
depth = mean([x.get_depth() for x in result.estimators_])
|
|
||||||
else:
|
|
||||||
nodes = leaves = depth = 0
|
|
||||||
elif name == "RandomForest":
|
elif name == "RandomForest":
|
||||||
leaves = mean([x.get_n_leaves() for x in result.estimators_])
|
leaves = mean([x.get_n_leaves() for x in result.estimators_])
|
||||||
depth = mean([x.get_depth() for x in result.estimators_])
|
depth = mean([x.get_depth() for x in result.estimators_])
|
||||||
|
84
benchmark/tests/Models_test.py
Normal file
84
benchmark/tests/Models_test.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
import unittest
|
||||||
|
import warnings
|
||||||
|
from sklearn.exceptions import ConvergenceWarning
|
||||||
|
from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier
|
||||||
|
from sklearn.ensemble import (
|
||||||
|
RandomForestClassifier,
|
||||||
|
BaggingClassifier,
|
||||||
|
AdaBoostClassifier,
|
||||||
|
)
|
||||||
|
from sklearn.svm import SVC
|
||||||
|
from sklearn.datasets import load_wine
|
||||||
|
from stree import Stree
|
||||||
|
from wodt import Wodt
|
||||||
|
from odte import Odte
|
||||||
|
from ..Models import Models
|
||||||
|
|
||||||
|
|
||||||
|
class ModelTest(unittest.TestCase):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def test_Models(self):
|
||||||
|
test = {
|
||||||
|
"STree": Stree,
|
||||||
|
"Wodt": Wodt,
|
||||||
|
"ODTE": Odte,
|
||||||
|
"Cart": DecisionTreeClassifier,
|
||||||
|
"SVC": SVC,
|
||||||
|
"RandomForest": RandomForestClassifier,
|
||||||
|
"ExtraTree": ExtraTreeClassifier,
|
||||||
|
}
|
||||||
|
for key, value in test.items():
|
||||||
|
self.assertIsInstance(Models.get_model(key), value)
|
||||||
|
|
||||||
|
def test_BaggingStree(self):
|
||||||
|
clf = Models.get_model("BaggingStree")
|
||||||
|
self.assertIsInstance(clf, BaggingClassifier)
|
||||||
|
clf_base = clf.base_estimator
|
||||||
|
self.assertIsInstance(clf_base, Stree)
|
||||||
|
|
||||||
|
def test_BaggingWodt(self):
|
||||||
|
clf = Models.get_model("BaggingWodt")
|
||||||
|
self.assertIsInstance(clf, BaggingClassifier)
|
||||||
|
clf_base = clf.base_estimator
|
||||||
|
self.assertIsInstance(clf_base, Wodt)
|
||||||
|
|
||||||
|
def test_AdaBoostStree(self):
|
||||||
|
clf = Models.get_model("AdaBoostStree")
|
||||||
|
self.assertIsInstance(clf, AdaBoostClassifier)
|
||||||
|
clf_base = clf.base_estimator
|
||||||
|
self.assertIsInstance(clf_base, Stree)
|
||||||
|
|
||||||
|
def test_unknown_classifier(self):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
Models.get_model("unknown")
|
||||||
|
|
||||||
|
def test_bogus_Stree(self):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
Models.get_model("Stree")
|
||||||
|
|
||||||
|
def test_bogus_Odte(self):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
Models.get_model("Odte")
|
||||||
|
|
||||||
|
def test_get_complexity(self):
|
||||||
|
warnings.filterwarnings("ignore", category=ConvergenceWarning)
|
||||||
|
test = {
|
||||||
|
"STree": (11, 6, 4),
|
||||||
|
"Wodt": (303, 152, 50),
|
||||||
|
"ODTE": (7.84, 4.42, 3.36),
|
||||||
|
"Cart": (23, 12, 5),
|
||||||
|
"SVC": (0, 0, 0),
|
||||||
|
"RandomForest": (21.3, 11, 5.26),
|
||||||
|
"ExtraTree": (0, 38, 0),
|
||||||
|
"AdaBoostStree": (12.25, 6.625, 4.75),
|
||||||
|
"BaggingStree": (8.4, 4.7, 3.5),
|
||||||
|
"BaggingWodt": (272, 136.5, 50),
|
||||||
|
}
|
||||||
|
X, y = load_wine(return_X_y=True)
|
||||||
|
for key, value in test.items():
|
||||||
|
clf = Models.get_model(key, random_state=1)
|
||||||
|
clf.fit(X, y)
|
||||||
|
# print(key, Models.get_complexity(key, clf))
|
||||||
|
self.assertSequenceEqual(Models.get_complexity(key, clf), value)
|
@@ -6,11 +6,6 @@ from ..Utils import Folders, Files, Symbols, TextColor, EnvData, EnvDefault
|
|||||||
|
|
||||||
|
|
||||||
class UtilTest(unittest.TestCase):
|
class UtilTest(unittest.TestCase):
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
self._random_state = 1
|
|
||||||
self._kernels = ["liblinear", "linear", "rbf", "poly", "sigmoid"]
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
def test_Folders(self):
|
def test_Folders(self):
|
||||||
self.assertEqual("results", Folders.results)
|
self.assertEqual("results", Folders.results)
|
||||||
self.assertEqual("hidden_results", Folders.hidden_results)
|
self.assertEqual("hidden_results", Folders.hidden_results)
|
||||||
|
@@ -1,3 +1,4 @@
|
|||||||
from .Util_test import UtilTest
|
from .Util_test import UtilTest
|
||||||
|
from .Models_test import ModelTest
|
||||||
|
|
||||||
all = ["UtilTest"]
|
all = ["UtilTest", "ModelTest"]
|
||||||
|
Reference in New Issue
Block a user