Adapt some notebooks

This commit is contained in:
2020-05-30 11:09:59 +02:00
parent a22ae81b54
commit 724a4855fb
5 changed files with 279 additions and 78 deletions

View File

@@ -152,11 +152,6 @@ class Stree(BaseEstimator, ClassifierMixin):
# doesn't work with multiclass as each sample has to do inner product with its own coeficients
# computes positition of every sample is w.r.t. the hyperplane
res = self._linear_function(data, node)
# data_up, data_down = self._split_array(data, down)
# indices_up, indices_down = self._split_array(indices, down)
# res_up, res_down = self._split_array(res, down)
# weight_up, weight_down = self._split_array(weights, down)
#return [data_up, indices_up, data_down, indices_down, weight_up, weight_down, res_up, res_down]
return res
def _split_criteria(self, data: np.array) -> np.array:
@@ -176,7 +171,6 @@ class Stree(BaseEstimator, ClassifierMixin):
sample_weight = _check_sample_weight(sample_weight, X)
check_classification_targets(y)
# Initialize computed parameters
#self.random_state = check_random_state(self.random_state)
self.classes_ = np.unique(y)
self.n_iter_ = self.max_iter
self.depth_ = 0
@@ -316,8 +310,7 @@ class Stree(BaseEstimator, ClassifierMixin):
# sklearn check
check_is_fitted(self)
yp = self.predict(X).reshape(y.shape)
right = (yp == y).astype(int)
return np.sum(right) / len(y)
return np.mean(yp == y)
def __iter__(self) -> Siterator:
try: