mirror of
https://github.com/Doctorado-ML/benchmark.git
synced 2025-08-16 07:55:54 +00:00
Add XGBoost classifier
This commit is contained in:
@@ -9,6 +9,7 @@ from sklearn.svm import SVC
|
|||||||
from stree import Stree
|
from stree import Stree
|
||||||
from wodt import Wodt
|
from wodt import Wodt
|
||||||
from odte import Odte
|
from odte import Odte
|
||||||
|
from xgboost import XGBClassifier
|
||||||
|
|
||||||
|
|
||||||
class Models:
|
class Models:
|
||||||
@@ -39,6 +40,8 @@ class Models:
|
|||||||
return BaggingClassifier(
|
return BaggingClassifier(
|
||||||
base_estimator=clf, random_state=random_state
|
base_estimator=clf, random_state=random_state
|
||||||
)
|
)
|
||||||
|
if name == "XGBoost":
|
||||||
|
return XGBClassifier(random_state=random_state)
|
||||||
if name == "AdaBoostStree":
|
if name == "AdaBoostStree":
|
||||||
clf = Stree(
|
clf = Stree(
|
||||||
random_state=random_state,
|
random_state=random_state,
|
||||||
@@ -77,7 +80,7 @@ class Models:
|
|||||||
leaves = mean([x.get_n_leaves() for x in result.estimators_])
|
leaves = mean([x.get_n_leaves() for x in result.estimators_])
|
||||||
depth = mean([x.get_depth() for x in result.estimators_])
|
depth = mean([x.get_depth() for x in result.estimators_])
|
||||||
nodes = mean([x.tree_.node_count for x in result.estimators_])
|
nodes = mean([x.tree_.node_count for x in result.estimators_])
|
||||||
elif name == "SVC":
|
elif name == "SVC" or name == "XGBoost":
|
||||||
nodes = leaves = depth = 0
|
nodes = leaves = depth = 0
|
||||||
else:
|
else:
|
||||||
nodes, leaves = result.nodes_leaves()
|
nodes, leaves = result.nodes_leaves()
|
||||||
|
@@ -11,6 +11,7 @@ from sklearn.datasets import load_wine
|
|||||||
from stree import Stree
|
from stree import Stree
|
||||||
from wodt import Wodt
|
from wodt import Wodt
|
||||||
from odte import Odte
|
from odte import Odte
|
||||||
|
from xgboost import XGBClassifier
|
||||||
from .TestBase import TestBase
|
from .TestBase import TestBase
|
||||||
from ..Models import Models
|
from ..Models import Models
|
||||||
|
|
||||||
@@ -25,6 +26,7 @@ class ModelTest(TestBase):
|
|||||||
"SVC": SVC,
|
"SVC": SVC,
|
||||||
"RandomForest": RandomForestClassifier,
|
"RandomForest": RandomForestClassifier,
|
||||||
"ExtraTree": ExtraTreeClassifier,
|
"ExtraTree": ExtraTreeClassifier,
|
||||||
|
"XGBoost": XGBClassifier,
|
||||||
}
|
}
|
||||||
for key, value in test.items():
|
for key, value in test.items():
|
||||||
self.assertIsInstance(Models.get_model(key), value)
|
self.assertIsInstance(Models.get_model(key), value)
|
||||||
|
@@ -5,3 +5,4 @@ mufs
|
|||||||
xlsxwriter
|
xlsxwriter
|
||||||
openpyxl
|
openpyxl
|
||||||
tqdm
|
tqdm
|
||||||
|
xgboost
|
||||||
|
Reference in New Issue
Block a user