From 9334951d1b84d9fb3420054b4be370d7dad91bd3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Montan=CC=83ana?= Date: Mon, 15 Jun 2020 11:09:11 +0200 Subject: [PATCH] #2 Cosmetic and style updates --- stree/Strees.py | 8 +++++--- stree/tests/Splitter_test.py | 2 +- stree/tests/Stree_test.py | 8 ++++---- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/stree/Strees.py b/stree/Strees.py index 7624972..ceeed7a 100644 --- a/stree/Strees.py +++ b/stree/Strees.py @@ -268,7 +268,8 @@ class Splitter: data = self.decision_criteria(data, node._y) self._down = data > 0 - def _distances(self, node: Snode, data: np.ndarray) -> np.array: + @staticmethod + def _distances(node: Snode, data: np.ndarray) -> np.array: """Compute distances of the samples to the hyperplane of the node :param node: node containing the svm classifier @@ -498,7 +499,8 @@ class Stree(BaseEstimator, ClassifierMixin): ) ) - def _reorder_results(self, y: np.array, indices: np.array) -> np.array: + @staticmethod + def _reorder_results(y: np.array, indices: np.array) -> np.array: """Reorder an array based on the array of indices passed :param y: data untidy @@ -579,7 +581,7 @@ class Stree(BaseEstimator, ClassifierMixin): X, y = check_X_y(X, y) y_pred = self.predict(X).reshape(y.shape) # Compute accuracy for each possible representation - y_type, y_true, y_pred = _check_targets(y, y_pred) + _, y_true, y_pred = _check_targets(y, y_pred) check_consistent_length(y_true, y_pred, sample_weight) score = y_true == y_pred return _weighted_sum(score, sample_weight, normalize=True) diff --git a/stree/tests/Splitter_test.py b/stree/tests/Splitter_test.py index b620ce1..68c6123 100644 --- a/stree/tests/Splitter_test.py +++ b/stree/tests/Splitter_test.py @@ -13,8 +13,8 @@ class Splitter_test(unittest.TestCase): self._random_state = 1 super().__init__(*args, **kwargs) + @staticmethod def build( - self, clf=LinearSVC(), min_samples_split=0, splitter_type="random", diff --git a/stree/tests/Stree_test.py b/stree/tests/Stree_test.py index 0fea9e5..ccc0442 100644 --- a/stree/tests/Stree_test.py +++ b/stree/tests/Stree_test.py @@ -67,9 +67,8 @@ class Stree_test(unittest.TestCase): clf.fit(*load_dataset(self._random_state)) self._check_tree(clf.tree_) - def _find_out( - self, px: np.array, x_original: np.array, y_original - ) -> list: + @staticmethod + def _find_out(px: np.array, x_original: np.array, y_original) -> list: """Find the original values of y for a given array of samples Arguments: @@ -163,7 +162,8 @@ class Stree_test(unittest.TestCase): self.assertListEqual(expected, computed) self.assertEqual(expected_string, str(clf)) - def test_is_a_sklearn_classifier(self): + @staticmethod + def test_is_a_sklearn_classifier(): import warnings from sklearn.exceptions import ConvergenceWarning