mirror of
https://github.com/Doctorado-ML/benchmark.git
synced 2025-08-16 16:05:54 +00:00
Add BaggingStree and BaggingWodt models
This commit is contained in:
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -159,7 +159,6 @@ class Experiment:
|
|||||||
self.title = title
|
self.title = title
|
||||||
self.stratified = stratified == "1"
|
self.stratified = stratified == "1"
|
||||||
self.stratified_class = StratifiedKFold if self.stratified else KFold
|
self.stratified_class = StratifiedKFold if self.stratified else KFold
|
||||||
self.model = Models.get_model(model_name)
|
|
||||||
self.datasets = datasets
|
self.datasets = datasets
|
||||||
dictionary = json.loads(hyperparams_dict)
|
dictionary = json.loads(hyperparams_dict)
|
||||||
hyper = BestResults(
|
hyper = BestResults(
|
||||||
@@ -185,8 +184,10 @@ class Experiment:
|
|||||||
return self.output_file
|
return self.output_file
|
||||||
|
|
||||||
def _build_classifier(self, random_state, hyperparameters):
|
def _build_classifier(self, random_state, hyperparameters):
|
||||||
clf = self.model(random_state=random_state)
|
self.model = Models.get_model(self.model_name, random_state)
|
||||||
|
clf = self.model
|
||||||
clf.set_params(**hyperparameters)
|
clf.set_params(**hyperparameters)
|
||||||
|
clf.set_params(random_state=random_state)
|
||||||
return clf
|
return clf
|
||||||
|
|
||||||
def _init_experiment(self):
|
def _init_experiment(self):
|
||||||
|
@@ -9,23 +9,27 @@ from odte import Odte
|
|||||||
|
|
||||||
class Models:
|
class Models:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_model(name):
|
def get_model(name, random_state=None):
|
||||||
if name == "STree":
|
if name == "STree":
|
||||||
return Stree
|
return Stree()
|
||||||
if name == "Cart":
|
if name == "Cart":
|
||||||
return DecisionTreeClassifier
|
return DecisionTreeClassifier()
|
||||||
if name == "ExtraTree":
|
if name == "ExtraTree":
|
||||||
return ExtraTreeClassifier
|
return ExtraTreeClassifier()
|
||||||
if name == "Wodt":
|
if name == "Wodt":
|
||||||
return Wodt
|
return Wodt()
|
||||||
if name == "SVC":
|
if name == "SVC":
|
||||||
return SVC
|
return SVC()
|
||||||
if name == "ODTE":
|
if name == "ODTE":
|
||||||
return Odte
|
return Odte()
|
||||||
if name == "Bagging":
|
if name == "BaggingStree":
|
||||||
return BaggingClassifier
|
clf = Stree(random_state=random_state)
|
||||||
|
return BaggingClassifier(base_estimator=clf)
|
||||||
|
if name == "BaggingWodt":
|
||||||
|
clf = Wodt(random_state=random_state)
|
||||||
|
return BaggingClassifier(base_estimator=clf)
|
||||||
if name == "RandomForest":
|
if name == "RandomForest":
|
||||||
return RandomForestClassifier
|
return RandomForestClassifier()
|
||||||
msg = f"No model recognized {name}"
|
msg = f"No model recognized {name}"
|
||||||
if name in ("Stree", "stree"):
|
if name in ("Stree", "stree"):
|
||||||
msg += ", did you mean STree?"
|
msg += ", did you mean STree?"
|
||||||
@@ -43,7 +47,7 @@ class Models:
|
|||||||
nodes = 0
|
nodes = 0
|
||||||
leaves = result.get_n_leaves()
|
leaves = result.get_n_leaves()
|
||||||
depth = 0
|
depth = 0
|
||||||
elif name == "Bagging":
|
elif name.startswith("Bagging"):
|
||||||
if hasattr(result.base_estimator_, "nodes_leaves"):
|
if hasattr(result.base_estimator_, "nodes_leaves"):
|
||||||
nodes, leaves = list(
|
nodes, leaves = list(
|
||||||
zip(*[x.nodes_leaves() for x in result.estimators_])
|
zip(*[x.nodes_leaves() for x in result.estimators_])
|
||||||
|
Reference in New Issue
Block a user