mirror of
https://github.com/Doctorado-ML/benchmark.git
synced 2025-08-15 15:35:52 +00:00
Add XGBoost classifier
This commit is contained in:
@@ -9,6 +9,7 @@ from sklearn.svm import SVC
|
||||
from stree import Stree
|
||||
from wodt import Wodt
|
||||
from odte import Odte
|
||||
from xgboost import XGBClassifier
|
||||
|
||||
|
||||
class Models:
|
||||
@@ -39,6 +40,8 @@ class Models:
|
||||
return BaggingClassifier(
|
||||
base_estimator=clf, random_state=random_state
|
||||
)
|
||||
if name == "XGBoost":
|
||||
return XGBClassifier(random_state=random_state)
|
||||
if name == "AdaBoostStree":
|
||||
clf = Stree(
|
||||
random_state=random_state,
|
||||
@@ -77,7 +80,7 @@ class Models:
|
||||
leaves = mean([x.get_n_leaves() for x in result.estimators_])
|
||||
depth = mean([x.get_depth() for x in result.estimators_])
|
||||
nodes = mean([x.tree_.node_count for x in result.estimators_])
|
||||
elif name == "SVC":
|
||||
elif name == "SVC" or name == "XGBoost":
|
||||
nodes = leaves = depth = 0
|
||||
else:
|
||||
nodes, leaves = result.nodes_leaves()
|
||||
|
@@ -11,6 +11,7 @@ from sklearn.datasets import load_wine
|
||||
from stree import Stree
|
||||
from wodt import Wodt
|
||||
from odte import Odte
|
||||
from xgboost import XGBClassifier
|
||||
from .TestBase import TestBase
|
||||
from ..Models import Models
|
||||
|
||||
@@ -25,6 +26,7 @@ class ModelTest(TestBase):
|
||||
"SVC": SVC,
|
||||
"RandomForest": RandomForestClassifier,
|
||||
"ExtraTree": ExtraTreeClassifier,
|
||||
"XGBoost": XGBClassifier,
|
||||
}
|
||||
for key, value in test.items():
|
||||
self.assertIsInstance(Models.get_model(key), value)
|
||||
|
@@ -5,3 +5,4 @@ mufs
|
||||
xlsxwriter
|
||||
openpyxl
|
||||
tqdm
|
||||
xgboost
|
||||
|
Reference in New Issue
Block a user