Update Models_tests

This commit is contained in:
Ricardo Montañana Gómez
2023-01-14 13:05:44 +01:00
parent acfbafbdce
commit 7ef88bd5c7
3 changed files with 7 additions and 8 deletions

View File

@@ -193,7 +193,6 @@ class Datasets:
} }
def load(self, name, dataframe=False): def load(self, name, dataframe=False):
try: try:
class_name = self.class_names[self.data_sets.index(name)] class_name = self.class_names[self.data_sets.index(name)]
X, y = self.dataset.load(name, class_name) X, y = self.dataset.load(name, class_name)

View File

@@ -47,20 +47,20 @@ class Models:
"Wodt": Wodt(random_state=random_state), "Wodt": Wodt(random_state=random_state),
"SVC": SVC(random_state=random_state), "SVC": SVC(random_state=random_state),
"ODTE": Odte( "ODTE": Odte(
base_estimator=Stree(random_state=random_state), estimator=Stree(random_state=random_state),
random_state=random_state, random_state=random_state,
), ),
"BaggingStree": BaggingClassifier( "BaggingStree": BaggingClassifier(
base_estimator=Stree(random_state=random_state), estimator=Stree(random_state=random_state),
random_state=random_state, random_state=random_state,
), ),
"BaggingWodt": BaggingClassifier( "BaggingWodt": BaggingClassifier(
base_estimator=Wodt(random_state=random_state), estimator=Wodt(random_state=random_state),
random_state=random_state, random_state=random_state,
), ),
"XGBoost": XGBClassifier(random_state=random_state), "XGBoost": XGBClassifier(random_state=random_state),
"AdaBoostStree": AdaBoostClassifier( "AdaBoostStree": AdaBoostClassifier(
base_estimator=Stree( estimator=Stree(
random_state=random_state, random_state=random_state,
), ),
algorithm="SAMME", algorithm="SAMME",

View File

@@ -70,19 +70,19 @@ class ModelTest(TestBase):
def test_BaggingStree(self): def test_BaggingStree(self):
clf = Models.get_model("BaggingStree") clf = Models.get_model("BaggingStree")
self.assertIsInstance(clf, BaggingClassifier) self.assertIsInstance(clf, BaggingClassifier)
clf_base = clf.base_estimator clf_base = clf.estimator
self.assertIsInstance(clf_base, Stree) self.assertIsInstance(clf_base, Stree)
def test_BaggingWodt(self): def test_BaggingWodt(self):
clf = Models.get_model("BaggingWodt") clf = Models.get_model("BaggingWodt")
self.assertIsInstance(clf, BaggingClassifier) self.assertIsInstance(clf, BaggingClassifier)
clf_base = clf.base_estimator clf_base = clf.estimator
self.assertIsInstance(clf_base, Wodt) self.assertIsInstance(clf_base, Wodt)
def test_AdaBoostStree(self): def test_AdaBoostStree(self):
clf = Models.get_model("AdaBoostStree") clf = Models.get_model("AdaBoostStree")
self.assertIsInstance(clf, AdaBoostClassifier) self.assertIsInstance(clf, AdaBoostClassifier)
clf_base = clf.base_estimator clf_base = clf.estimator
self.assertIsInstance(clf_base, Stree) self.assertIsInstance(clf_base, Stree)
def test_unknown_classifier(self): def test_unknown_classifier(self):