Fix parameter missing in method overload

This commit is contained in:
2020-06-06 18:18:03 +02:00
parent cb10aea36e
commit 37577849db
3 changed files with 9 additions and 4 deletions

View File

@@ -13,7 +13,7 @@ notifications:
# command to run tests # command to run tests
script: script:
- black --check --diff stree - black --check --diff stree
- flake8 --count --exclude __init__.py stree - flake8 --count stree
- coverage run -m unittest -v stree.tests - coverage run -m unittest -v stree.tests
after_success: after_success:
- codecov - codecov

View File

@@ -146,7 +146,9 @@ class Stree_grapher(Stree):
mirror.set_up(self._copy_tree(node.get_up())) mirror.set_up(self._copy_tree(node.get_up()))
return mirror return mirror
def fit(self, X: np.array, y: np.array) -> Stree: def fit(
self, X: np.array, y: np.array, sample_weight: np.array = None
) -> "Stree_grapher":
"""Fit the Stree and copy the tree in a Snode_graph tree """Fit the Stree and copy the tree in a Snode_graph tree
:param X: Dataset :param X: Dataset
@@ -159,10 +161,10 @@ class Stree_grapher(Stree):
if X.shape[1] != 3: if X.shape[1] != 3:
self._pca = PCA(n_components=3) self._pca = PCA(n_components=3)
X = self._pca.fit_transform(X) X = self._pca.fit_transform(X)
res = super().fit(X, y) super().fit(X, y, sample_weight=sample_weight)
self._tree_gr = self._copy_tree(self.tree_) self._tree_gr = self._copy_tree(self.tree_)
self._fitted = True self._fitted = True
return res return self
def score(self, X: np.array, y: np.array) -> float: def score(self, X: np.array, y: np.array) -> float:
self._check_fitted() self._check_fitted()

View File

@@ -181,6 +181,7 @@ class Snode_graph_test(unittest.TestCase):
os.remove(file_name) os.remove(file_name)
def test_plot_hyperplane_with_distribution(self): def test_plot_hyperplane_with_distribution(self):
plt.close()
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
matplotlib.use("Agg") matplotlib.use("Agg")
@@ -190,6 +191,7 @@ class Snode_graph_test(unittest.TestCase):
self.assertEqual(1, num_figures_after - num_figures_before) self.assertEqual(1, num_figures_after - num_figures_before)
def test_plot_hyperplane_without_distribution(self): def test_plot_hyperplane_without_distribution(self):
plt.close()
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
matplotlib.use("Agg") matplotlib.use("Agg")
@@ -199,6 +201,7 @@ class Snode_graph_test(unittest.TestCase):
self.assertEqual(1, num_figures_after - num_figures_before) self.assertEqual(1, num_figures_after - num_figures_before)
def test_plot_distribution(self): def test_plot_distribution(self):
plt.close()
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
matplotlib.use("Agg") matplotlib.use("Agg")