mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-15 23:46:02 +00:00
Fix parameter missing in method overload
This commit is contained in:
@@ -13,7 +13,7 @@ notifications:
|
|||||||
# command to run tests
|
# command to run tests
|
||||||
script:
|
script:
|
||||||
- black --check --diff stree
|
- black --check --diff stree
|
||||||
- flake8 --count --exclude __init__.py stree
|
- flake8 --count stree
|
||||||
- coverage run -m unittest -v stree.tests
|
- coverage run -m unittest -v stree.tests
|
||||||
after_success:
|
after_success:
|
||||||
- codecov
|
- codecov
|
||||||
|
@@ -146,7 +146,9 @@ class Stree_grapher(Stree):
|
|||||||
mirror.set_up(self._copy_tree(node.get_up()))
|
mirror.set_up(self._copy_tree(node.get_up()))
|
||||||
return mirror
|
return mirror
|
||||||
|
|
||||||
def fit(self, X: np.array, y: np.array) -> Stree:
|
def fit(
|
||||||
|
self, X: np.array, y: np.array, sample_weight: np.array = None
|
||||||
|
) -> "Stree_grapher":
|
||||||
"""Fit the Stree and copy the tree in a Snode_graph tree
|
"""Fit the Stree and copy the tree in a Snode_graph tree
|
||||||
|
|
||||||
:param X: Dataset
|
:param X: Dataset
|
||||||
@@ -159,10 +161,10 @@ class Stree_grapher(Stree):
|
|||||||
if X.shape[1] != 3:
|
if X.shape[1] != 3:
|
||||||
self._pca = PCA(n_components=3)
|
self._pca = PCA(n_components=3)
|
||||||
X = self._pca.fit_transform(X)
|
X = self._pca.fit_transform(X)
|
||||||
res = super().fit(X, y)
|
super().fit(X, y, sample_weight=sample_weight)
|
||||||
self._tree_gr = self._copy_tree(self.tree_)
|
self._tree_gr = self._copy_tree(self.tree_)
|
||||||
self._fitted = True
|
self._fitted = True
|
||||||
return res
|
return self
|
||||||
|
|
||||||
def score(self, X: np.array, y: np.array) -> float:
|
def score(self, X: np.array, y: np.array) -> float:
|
||||||
self._check_fitted()
|
self._check_fitted()
|
||||||
|
@@ -181,6 +181,7 @@ class Snode_graph_test(unittest.TestCase):
|
|||||||
os.remove(file_name)
|
os.remove(file_name)
|
||||||
|
|
||||||
def test_plot_hyperplane_with_distribution(self):
|
def test_plot_hyperplane_with_distribution(self):
|
||||||
|
plt.close()
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter("ignore")
|
warnings.simplefilter("ignore")
|
||||||
matplotlib.use("Agg")
|
matplotlib.use("Agg")
|
||||||
@@ -190,6 +191,7 @@ class Snode_graph_test(unittest.TestCase):
|
|||||||
self.assertEqual(1, num_figures_after - num_figures_before)
|
self.assertEqual(1, num_figures_after - num_figures_before)
|
||||||
|
|
||||||
def test_plot_hyperplane_without_distribution(self):
|
def test_plot_hyperplane_without_distribution(self):
|
||||||
|
plt.close()
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter("ignore")
|
warnings.simplefilter("ignore")
|
||||||
matplotlib.use("Agg")
|
matplotlib.use("Agg")
|
||||||
@@ -199,6 +201,7 @@ class Snode_graph_test(unittest.TestCase):
|
|||||||
self.assertEqual(1, num_figures_after - num_figures_before)
|
self.assertEqual(1, num_figures_after - num_figures_before)
|
||||||
|
|
||||||
def test_plot_distribution(self):
|
def test_plot_distribution(self):
|
||||||
|
plt.close()
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.simplefilter("ignore")
|
warnings.simplefilter("ignore")
|
||||||
matplotlib.use("Agg")
|
matplotlib.use("Agg")
|
||||||
|
Reference in New Issue
Block a user