mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-18 17:06:01 +00:00
#3 Rewrite some tests & remove use_predictions
Remove use_predictions parameter as of now, the model always use it
This commit is contained in:
@@ -126,14 +126,12 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
random_state: int = None,
|
||||
max_depth: int = None,
|
||||
tol: float = 1e-4,
|
||||
use_predictions: bool = False,
|
||||
min_samples_split: int = 0,
|
||||
):
|
||||
self.max_iter = max_iter
|
||||
self.C = C
|
||||
self.kernel = kernel
|
||||
self.random_state = random_state
|
||||
self.use_predictions = use_predictions
|
||||
self.max_depth = max_depth
|
||||
self.tol = tol
|
||||
self.min_samples_split = min_samples_split
|
||||
@@ -172,6 +170,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
:rtype: list
|
||||
"""
|
||||
up = ~down
|
||||
print(self.kernel, up.shape, down.shape)
|
||||
return (
|
||||
origin[up[:, 0]] if any(up) else None,
|
||||
origin[down[:, 0]] if any(down) else None,
|
||||
@@ -188,14 +187,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
the hyperplane of the node
|
||||
:rtype: np.array
|
||||
"""
|
||||
if self.use_predictions:
|
||||
res = np.expand_dims(node._clf.decision_function(data), 1)
|
||||
else:
|
||||
# doesn't work with multiclass as each sample has to do inner
|
||||
# product with its own coefficients computes positition of every
|
||||
# sample is w.r.t. the hyperplane
|
||||
res = self._linear_function(data, node)
|
||||
return res
|
||||
return np.expand_dims(node._clf.decision_function(data), 1)
|
||||
|
||||
def _split_criteria(self, data: np.array) -> np.array:
|
||||
"""Set the criteria to split arrays
|
||||
|
Reference in New Issue
Block a user