#3 Complete multiclass in Stree

Add multiclass dimensions management in distances method
Add gamma hyperparameter for non linear kernels
This commit is contained in:
2020-06-08 13:54:24 +02:00
parent 3a48d8b405
commit d7c0bc3bc5
4 changed files with 41 additions and 50 deletions

View File

@@ -126,6 +126,7 @@ class Stree(BaseEstimator, ClassifierMixin):
random_state: int = None,
max_depth: int = None,
tol: float = 1e-4,
gamma="scale",
min_samples_split: int = 0,
):
self.max_iter = max_iter
@@ -134,6 +135,7 @@ class Stree(BaseEstimator, ClassifierMixin):
self.random_state = random_state
self.max_depth = max_depth
self.tol = tol
self.gamma = gamma
self.min_samples_split = min_samples_split
def _more_tags(self) -> dict:
@@ -144,21 +146,6 @@ class Stree(BaseEstimator, ClassifierMixin):
"""
return {"binary_only": True, "requires_y": True}
def _linear_function(self, data: np.array, node: Snode) -> np.array:
"""Compute the distance of set of samples to a hyperplane, in
multiclass classification it should compute the distance to a
hyperplane of each class
:param data: dataset of samples
:type data: np.array shape(m, n)
:param node: the node that contains the hyperplance coefficients
:type node: Snode shape(1, n)
:return: array of distances of each sample to the hyperplane
:rtype: np.array
"""
coef = node._clf.coef_[0, :].reshape(-1, data.shape[1])
return data.dot(coef.T) + node._clf.intercept_[0]
def _split_array(self, origin: np.array, down: np.array) -> list:
"""Split an array in two based on indices passed as down and its complement
@@ -170,7 +157,6 @@ class Stree(BaseEstimator, ClassifierMixin):
:rtype: list
"""
up = ~down
print(self.kernel, up.shape, down.shape)
return (
origin[up[:, 0]] if any(up) else None,
origin[down[:, 0]] if any(down) else None,
@@ -187,7 +173,12 @@ class Stree(BaseEstimator, ClassifierMixin):
the hyperplane of the node
:rtype: np.array
"""
return np.expand_dims(node._clf.decision_function(data), 1)
res = node._clf.decision_function(data)
if res.ndim == 1:
return np.expand_dims(res, 1)
elif res.shape[1] > 1:
res = np.delete(res, slice(1, res.shape[1]), axis=1)
return res
def _split_criteria(self, data: np.array) -> np.array:
"""Set the criteria to split arrays
@@ -256,9 +247,8 @@ class Stree(BaseEstimator, ClassifierMixin):
run_tree(self.tree_)
def _build_clf(self):
""" Select the correct classifier for the node
""" Build the correct classifier for the node
"""
return (
LinearSVC(
max_iter=self.max_iter,
@@ -272,6 +262,7 @@ class Stree(BaseEstimator, ClassifierMixin):
max_iter=self.max_iter,
tol=self.tol,
C=self.C,
gamma=self.gamma,
)
)