mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-16 07:56:06 +00:00
Fix parameter missing in method overload
This commit is contained in:
@@ -13,7 +13,7 @@ notifications:
|
||||
# command to run tests
|
||||
script:
|
||||
- black --check --diff stree
|
||||
- flake8 --count --exclude __init__.py stree
|
||||
- flake8 --count stree
|
||||
- coverage run -m unittest -v stree.tests
|
||||
after_success:
|
||||
- codecov
|
||||
|
@@ -146,7 +146,9 @@ class Stree_grapher(Stree):
|
||||
mirror.set_up(self._copy_tree(node.get_up()))
|
||||
return mirror
|
||||
|
||||
def fit(self, X: np.array, y: np.array) -> Stree:
|
||||
def fit(
|
||||
self, X: np.array, y: np.array, sample_weight: np.array = None
|
||||
) -> "Stree_grapher":
|
||||
"""Fit the Stree and copy the tree in a Snode_graph tree
|
||||
|
||||
:param X: Dataset
|
||||
@@ -159,10 +161,10 @@ class Stree_grapher(Stree):
|
||||
if X.shape[1] != 3:
|
||||
self._pca = PCA(n_components=3)
|
||||
X = self._pca.fit_transform(X)
|
||||
res = super().fit(X, y)
|
||||
super().fit(X, y, sample_weight=sample_weight)
|
||||
self._tree_gr = self._copy_tree(self.tree_)
|
||||
self._fitted = True
|
||||
return res
|
||||
return self
|
||||
|
||||
def score(self, X: np.array, y: np.array) -> float:
|
||||
self._check_fitted()
|
||||
|
@@ -181,6 +181,7 @@ class Snode_graph_test(unittest.TestCase):
|
||||
os.remove(file_name)
|
||||
|
||||
def test_plot_hyperplane_with_distribution(self):
|
||||
plt.close()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
matplotlib.use("Agg")
|
||||
@@ -190,6 +191,7 @@ class Snode_graph_test(unittest.TestCase):
|
||||
self.assertEqual(1, num_figures_after - num_figures_before)
|
||||
|
||||
def test_plot_hyperplane_without_distribution(self):
|
||||
plt.close()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
matplotlib.use("Agg")
|
||||
@@ -199,6 +201,7 @@ class Snode_graph_test(unittest.TestCase):
|
||||
self.assertEqual(1, num_figures_after - num_figures_before)
|
||||
|
||||
def test_plot_distribution(self):
|
||||
plt.close()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
matplotlib.use("Agg")
|
||||
|
Reference in New Issue
Block a user