diff --git a/benchmark/Models.py b/benchmark/Models.py index 627142d..1378218 100644 --- a/benchmark/Models.py +++ b/benchmark/Models.py @@ -9,6 +9,7 @@ from sklearn.svm import SVC from stree import Stree from wodt import Wodt from odte import Odte +from xgboost import XGBClassifier class Models: @@ -39,6 +40,8 @@ class Models: return BaggingClassifier( base_estimator=clf, random_state=random_state ) + if name == "XGBoost": + return XGBClassifier(random_state=random_state) if name == "AdaBoostStree": clf = Stree( random_state=random_state, @@ -77,7 +80,7 @@ class Models: leaves = mean([x.get_n_leaves() for x in result.estimators_]) depth = mean([x.get_depth() for x in result.estimators_]) nodes = mean([x.tree_.node_count for x in result.estimators_]) - elif name == "SVC": + elif name == "SVC" or name == "XGBoost": nodes = leaves = depth = 0 else: nodes, leaves = result.nodes_leaves() diff --git a/benchmark/tests/Models_test.py b/benchmark/tests/Models_test.py index 8bffb2d..9f27a11 100644 --- a/benchmark/tests/Models_test.py +++ b/benchmark/tests/Models_test.py @@ -11,6 +11,7 @@ from sklearn.datasets import load_wine from stree import Stree from wodt import Wodt from odte import Odte +from xgboost import XGBClassifier from .TestBase import TestBase from ..Models import Models @@ -25,6 +26,7 @@ class ModelTest(TestBase): "SVC": SVC, "RandomForest": RandomForestClassifier, "ExtraTree": ExtraTreeClassifier, + "XGBoost": XGBClassifier, } for key, value in test.items(): self.assertIsInstance(Models.get_model(key), value) diff --git a/requirements.txt b/requirements.txt index 9dedbe8..3af1412 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ mufs xlsxwriter openpyxl tqdm +xgboost