diff --git a/.travis.yml b/.travis.yml index 1725523..8007f8f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -13,7 +13,7 @@ notifications: # command to run tests script: - black --check --diff stree - - flake8 --count --exclude __init__.py stree + - flake8 --count stree - coverage run -m unittest -v stree.tests after_success: - codecov diff --git a/stree/Strees_grapher.py b/stree/Strees_grapher.py index c9c425e..12b2b47 100644 --- a/stree/Strees_grapher.py +++ b/stree/Strees_grapher.py @@ -146,7 +146,9 @@ class Stree_grapher(Stree): mirror.set_up(self._copy_tree(node.get_up())) return mirror - def fit(self, X: np.array, y: np.array) -> Stree: + def fit( + self, X: np.array, y: np.array, sample_weight: np.array = None + ) -> "Stree_grapher": """Fit the Stree and copy the tree in a Snode_graph tree :param X: Dataset @@ -159,10 +161,10 @@ class Stree_grapher(Stree): if X.shape[1] != 3: self._pca = PCA(n_components=3) X = self._pca.fit_transform(X) - res = super().fit(X, y) + super().fit(X, y, sample_weight=sample_weight) self._tree_gr = self._copy_tree(self.tree_) self._fitted = True - return res + return self def score(self, X: np.array, y: np.array) -> float: self._check_fitted() diff --git a/stree/tests/Strees_grapher_test.py b/stree/tests/Strees_grapher_test.py index e702dcc..9c3f874 100644 --- a/stree/tests/Strees_grapher_test.py +++ b/stree/tests/Strees_grapher_test.py @@ -181,6 +181,7 @@ class Snode_graph_test(unittest.TestCase): os.remove(file_name) def test_plot_hyperplane_with_distribution(self): + plt.close() with warnings.catch_warnings(): warnings.simplefilter("ignore") matplotlib.use("Agg") @@ -190,6 +191,7 @@ class Snode_graph_test(unittest.TestCase): self.assertEqual(1, num_figures_after - num_figures_before) def test_plot_hyperplane_without_distribution(self): + plt.close() with warnings.catch_warnings(): warnings.simplefilter("ignore") matplotlib.use("Agg") @@ -199,6 +201,7 @@ class Snode_graph_test(unittest.TestCase): self.assertEqual(1, num_figures_after - num_figures_before) def test_plot_distribution(self): + plt.close() with warnings.catch_warnings(): warnings.simplefilter("ignore") matplotlib.use("Agg")