Add XGBoost classifier

This commit is contained in:
2022-05-02 20:34:49 +02:00
parent e70209fe50
commit 01d9269880
3 changed files with 7 additions and 1 deletions

View File

@@ -9,6 +9,7 @@ from sklearn.svm import SVC
from stree import Stree from stree import Stree
from wodt import Wodt from wodt import Wodt
from odte import Odte from odte import Odte
from xgboost import XGBClassifier
class Models: class Models:
@@ -39,6 +40,8 @@ class Models:
return BaggingClassifier( return BaggingClassifier(
base_estimator=clf, random_state=random_state base_estimator=clf, random_state=random_state
) )
if name == "XGBoost":
return XGBClassifier(random_state=random_state)
if name == "AdaBoostStree": if name == "AdaBoostStree":
clf = Stree( clf = Stree(
random_state=random_state, random_state=random_state,
@@ -77,7 +80,7 @@ class Models:
leaves = mean([x.get_n_leaves() for x in result.estimators_]) leaves = mean([x.get_n_leaves() for x in result.estimators_])
depth = mean([x.get_depth() for x in result.estimators_]) depth = mean([x.get_depth() for x in result.estimators_])
nodes = mean([x.tree_.node_count for x in result.estimators_]) nodes = mean([x.tree_.node_count for x in result.estimators_])
elif name == "SVC": elif name == "SVC" or name == "XGBoost":
nodes = leaves = depth = 0 nodes = leaves = depth = 0
else: else:
nodes, leaves = result.nodes_leaves() nodes, leaves = result.nodes_leaves()

View File

@@ -11,6 +11,7 @@ from sklearn.datasets import load_wine
from stree import Stree from stree import Stree
from wodt import Wodt from wodt import Wodt
from odte import Odte from odte import Odte
from xgboost import XGBClassifier
from .TestBase import TestBase from .TestBase import TestBase
from ..Models import Models from ..Models import Models
@@ -25,6 +26,7 @@ class ModelTest(TestBase):
"SVC": SVC, "SVC": SVC,
"RandomForest": RandomForestClassifier, "RandomForest": RandomForestClassifier,
"ExtraTree": ExtraTreeClassifier, "ExtraTree": ExtraTreeClassifier,
"XGBoost": XGBClassifier,
} }
for key, value in test.items(): for key, value in test.items():
self.assertIsInstance(Models.get_model(key), value) self.assertIsInstance(Models.get_model(key), value)

View File

@@ -5,3 +5,4 @@ mufs
xlsxwriter xlsxwriter
openpyxl openpyxl
tqdm tqdm
xgboost