Add XGBoost classifier

This commit is contained in:
2022-05-02 20:34:49 +02:00
parent e70209fe50
commit 01d9269880
3 changed files with 7 additions and 1 deletions

View File

@@ -9,6 +9,7 @@ from sklearn.svm import SVC
from stree import Stree
from wodt import Wodt
from odte import Odte
from xgboost import XGBClassifier
class Models:
@@ -39,6 +40,8 @@ class Models:
return BaggingClassifier(
base_estimator=clf, random_state=random_state
)
if name == "XGBoost":
return XGBClassifier(random_state=random_state)
if name == "AdaBoostStree":
clf = Stree(
random_state=random_state,
@@ -77,7 +80,7 @@ class Models:
leaves = mean([x.get_n_leaves() for x in result.estimators_])
depth = mean([x.get_depth() for x in result.estimators_])
nodes = mean([x.tree_.node_count for x in result.estimators_])
elif name == "SVC":
elif name == "SVC" or name == "XGBoost":
nodes = leaves = depth = 0
else:
nodes, leaves = result.nodes_leaves()

View File

@@ -11,6 +11,7 @@ from sklearn.datasets import load_wine
from stree import Stree
from wodt import Wodt
from odte import Odte
from xgboost import XGBClassifier
from .TestBase import TestBase
from ..Models import Models
@@ -25,6 +26,7 @@ class ModelTest(TestBase):
"SVC": SVC,
"RandomForest": RandomForestClassifier,
"ExtraTree": ExtraTreeClassifier,
"XGBoost": XGBClassifier,
}
for key, value in test.items():
self.assertIsInstance(Models.get_model(key), value)

View File

@@ -5,3 +5,4 @@ mufs
xlsxwriter
openpyxl
tqdm
xgboost