Implement ovo strategy (#37)

* Implement ovo strategy
* Set ovo strategy as default
* Add kernel liblinear with LinearSVC classifier
* Fix weak test
This commit is contained in:
Ricardo Montañana Gómez
2021-05-10 12:16:53 +02:00
committed by GitHub
parent 5cef0f4875
commit 4f04e72670
6 changed files with 252 additions and 92 deletions

View File

@@ -155,6 +155,10 @@ class Siterator:
self._stack = []
self._push(tree)
def __iter__(self):
# Return the object itself so it acts as its own iterator; this
# completes the iterator interface
return self
def _push(self, node: Snode):
if node is not None:
self._stack.append(node)
@@ -373,16 +377,17 @@ class Splitter:
tuple
indices of the features selected
"""
# No feature reduction
if dataset.shape[1] == max_features:
# No feature reduction applies
return tuple(range(dataset.shape[1]))
# Random feature reduction
if self._feature_select == "random":
features_sets = self._generate_spaces(
dataset.shape[1], max_features
)
return self._select_best_set(dataset, labels, features_sets)
# return the KBest features
if self._feature_select == "best":
# Take KBest features
return (
SelectKBest(k=max_features)
.fit(dataset, labels)
@@ -569,6 +574,7 @@ class Stree(BaseEstimator, ClassifierMixin):
min_samples_split: int = 0,
max_features=None,
splitter: str = "random",
multiclass_strategy: str = "ovo",
normalize: bool = False,
):
self.max_iter = max_iter
@@ -585,6 +591,7 @@ class Stree(BaseEstimator, ClassifierMixin):
self.criterion = criterion
self.splitter = splitter
self.normalize = normalize
self.multiclass_strategy = multiclass_strategy
def _more_tags(self) -> dict:
"""Required by sklearn to supply features of the classifier
@@ -629,7 +636,23 @@ class Stree(BaseEstimator, ClassifierMixin):
f"Maximum depth has to be greater than 1... got (max_depth=\
{self.max_depth})"
)
kernels = ["linear", "rbf", "poly", "sigmoid"]
if self.multiclass_strategy not in ["ovr", "ovo"]:
    # Reject anything other than the two supported strategy names; the
    # message spells the parameter exactly as it is declared
    raise ValueError(
        "multiclass_strategy has to be either ovr or ovo"
        f" but got {self.multiclass_strategy}"
    )
if self.multiclass_strategy == "ovo":
if self.kernel == "liblinear":
raise ValueError(
"The kernel liblinear is incompatible with ovo "
"multiclass_strategy"
)
if self.split_criteria == "max_samples":
raise ValueError(
"The multiclass_strategy 'ovo' is incompatible with "
"split_criteria 'max_samples'"
)
kernels = ["liblinear", "linear", "rbf", "poly", "sigmoid"]
if self.kernel not in kernels:
raise ValueError(f"Kernel {self.kernel} not in {kernels}")
check_classification_targets(y)
@@ -749,7 +772,7 @@ class Stree(BaseEstimator, ClassifierMixin):
C=self.C,
tol=self.tol,
)
if self.kernel == "linear"
if self.kernel == "liblinear"
else SVC(
kernel=self.kernel,
max_iter=self.max_iter,
@@ -758,6 +781,7 @@ class Stree(BaseEstimator, ClassifierMixin):
gamma=self.gamma,
degree=self.degree,
random_state=self.random_state,
decision_function_shape=self.multiclass_strategy,
)
)