diff --git a/stree/Strees.py b/stree/Strees.py index 2f1b219..5a8a8fd 100644 --- a/stree/Strees.py +++ b/stree/Strees.py @@ -126,6 +126,7 @@ class Stree(BaseEstimator, ClassifierMixin): random_state: int = None, max_depth: int = None, tol: float = 1e-4, + gamma="scale", min_samples_split: int = 0, ): self.max_iter = max_iter @@ -134,6 +135,7 @@ class Stree(BaseEstimator, ClassifierMixin): self.random_state = random_state self.max_depth = max_depth self.tol = tol + self.gamma = gamma self.min_samples_split = min_samples_split def _more_tags(self) -> dict: @@ -144,21 +146,6 @@ class Stree(BaseEstimator, ClassifierMixin): """ return {"binary_only": True, "requires_y": True} - def _linear_function(self, data: np.array, node: Snode) -> np.array: - """Compute the distance of set of samples to a hyperplane, in - multiclass classification it should compute the distance to a - hyperplane of each class - - :param data: dataset of samples - :type data: np.array shape(m, n) - :param node: the node that contains the hyperplance coefficients - :type node: Snode shape(1, n) - :return: array of distances of each sample to the hyperplane - :rtype: np.array - """ - coef = node._clf.coef_[0, :].reshape(-1, data.shape[1]) - return data.dot(coef.T) + node._clf.intercept_[0] - def _split_array(self, origin: np.array, down: np.array) -> list: """Split an array in two based on indices passed as down and its complement @@ -170,7 +157,6 @@ class Stree(BaseEstimator, ClassifierMixin): :rtype: list """ up = ~down - print(self.kernel, up.shape, down.shape) return ( origin[up[:, 0]] if any(up) else None, origin[down[:, 0]] if any(down) else None, @@ -187,7 +173,12 @@ class Stree(BaseEstimator, ClassifierMixin): the hyperplane of the node :rtype: np.array """ - return np.expand_dims(node._clf.decision_function(data), 1) + res = node._clf.decision_function(data) + if res.ndim == 1: + return np.expand_dims(res, 1) + elif res.shape[1] > 1: + res = np.delete(res, slice(1, res.shape[1]), axis=1) + return res def _split_criteria(self, data: np.array) -> np.array: """Set the criteria to split arrays @@ -256,9 +247,8 @@ class Stree(BaseEstimator, ClassifierMixin): run_tree(self.tree_) def _build_clf(self): - """ Select the correct classifier for the node + """ Build the correct classifier for the node """ - return ( LinearSVC( max_iter=self.max_iter, @@ -272,6 +262,7 @@ class Stree(BaseEstimator, ClassifierMixin): max_iter=self.max_iter, tol=self.tol, C=self.C, + gamma=self.gamma, ) ) diff --git a/stree/Strees_grapher.py b/stree/Strees_grapher.py index 0033fe1..8e93631 100644 --- a/stree/Strees_grapher.py +++ b/stree/Strees_grapher.py @@ -41,6 +41,9 @@ class Snode_graph(Snode): def set_axis_limits(self, limits: tuple): self._xlimits, self._ylimits, self._zlimits = limits + def get_axis_limits(self) -> tuple: + return self._xlimits, self._ylimits, self._zlimits + def _set_graphics_axis(self, ax: Axes3D): ax.set_xlim(self._xlimits) ax.set_ylim(self._ylimits) @@ -50,7 +53,7 @@ class Snode_graph(Snode): self, save_folder: str = "./", save_prefix: str = "", save_seq: int = 1 ): _, fig = self.plot_hyperplane() - name = f"{save_folder}{save_prefix}STnode{save_seq}.png" + name = os.path.join(save_folder, f"{save_prefix}STnode{save_seq}.png") fig.savefig(name, bbox_inches="tight") plt.close(fig) diff --git a/stree/tests/Strees_grapher_test.py b/stree/tests/Strees_grapher_test.py index 26d615a..c7f6f15 100644 --- a/stree/tests/Strees_grapher_test.py +++ b/stree/tests/Strees_grapher_test.py @@ -8,7 +8,7 @@ import matplotlib.pyplot as plt import warnings from sklearn.datasets import make_classification -from stree import Stree_grapher, Snode_graph +from stree import Stree_grapher, Snode_graph, Snode def get_dataset(random_state=0, n_features=3): @@ -30,18 +30,14 @@ def get_dataset(random_state=0, n_features=3): class Stree_grapher_test(unittest.TestCase): def __init__(self, *args, **kwargs): - os.environ["TESTING"] = "1" self._random_state = 1 self._clf = Stree_grapher(dict(random_state=self._random_state)) self._clf.fit(*get_dataset(self._random_state, n_features=4)) super().__init__(*args, **kwargs) @classmethod - def tearDownClass(cls): - try: - os.environ.pop("TESTING") - except KeyError: - pass + def setUp(cls): + os.environ["TESTING"] = "1" def test_iterator(self): """Check preorder iterator @@ -73,7 +69,9 @@ class Stree_grapher_test(unittest.TestCase): self.assertGreater(accuracy_score, 0.86) def test_save_all(self): - folder_name = "/tmp/" + folder_name = os.path.join(os.sep, "tmp", "stree") + if os.path.isdir(folder_name): + os.rmdir(folder_name) file_names = [ os.path.join(folder_name, f"STnode{i}.png") for i in range(1, 8) ] @@ -85,6 +83,7 @@ class Stree_grapher_test(unittest.TestCase): self.assertTrue(os.path.exists(file_name)) self.assertEqual("png", imghdr.what(file_name)) os.remove(file_name) + os.rmdir(folder_name) def test_plot_all(self): with warnings.catch_warnings(): @@ -98,20 +97,14 @@ class Stree_grapher_test(unittest.TestCase): class Snode_graph_test(unittest.TestCase): def __init__(self, *args, **kwargs): - os.environ["TESTING"] = "1" self._random_state = 1 self._clf = Stree_grapher(dict(random_state=self._random_state)) self._clf.fit(*get_dataset(self._random_state)) super().__init__(*args, **kwargs) @classmethod - def tearDownClass(cls): - """Remove the testing environ variable - """ - try: - os.environ.pop("TESTING") - except KeyError: - pass + def setUp(cls): + os.environ["TESTING"] = "1" def test_plot_size(self): default = self._clf._tree_gr.get_plot_size() @@ -205,3 +198,14 @@ class Snode_graph_test(unittest.TestCase): self._clf._tree_gr.plot_distribution() num_figures_after = plt.gcf().number self.assertEqual(1, num_figures_after - num_figures_before) + + def test_set_axis_limits(self): + node = Snode_graph(Snode(None, None, None, "test")) + limits = (-2, 2), (-3, 3), (-4, 4) + node.set_axis_limits(limits) + computed = node.get_axis_limits() + x, y, z = limits + xx, yy, zz = computed + self.assertEqual(x, xx) + self.assertEqual(y, yy) + self.assertEqual(z, zz) diff --git a/stree/tests/Strees_test.py b/stree/tests/Strees_test.py index a94aa19..fe3cd1f 100644 --- a/stree/tests/Strees_test.py +++ b/stree/tests/Strees_test.py @@ -26,17 +26,13 @@ def get_dataset(random_state=0): class Stree_test(unittest.TestCase): def __init__(self, *args, **kwargs): - os.environ["TESTING"] = "1" self._random_state = 1 self._kernels = ["linear", "rbf", "poly"] super().__init__(*args, **kwargs) @classmethod - def tearDownClass(cls): - try: - os.environ.pop("TESTING") - except KeyError: - pass + def setUp(cls): + os.environ["TESTING"] = "1" def _check_tree(self, node: Snode): """Check recursively that the nodes that are not leaves have the @@ -79,6 +75,9 @@ class Stree_test(unittest.TestCase): def test_build_tree(self): """Check if the tree is built the same way as predictions of models """ + import warnings + + warnings.filterwarnings("ignore") for kernel in self._kernels: clf = Stree(kernel=kernel, random_state=self._random_state) clf.fit(*get_dataset(self._random_state)) @@ -260,20 +259,14 @@ class Stree_test(unittest.TestCase): class Snode_test(unittest.TestCase): def __init__(self, *args, **kwargs): - os.environ["TESTING"] = "1" self._random_state = 1 self._clf = Stree(random_state=self._random_state) self._clf.fit(*get_dataset(self._random_state)) super().__init__(*args, **kwargs) @classmethod - def tearDownClass(cls): - """[summary] - """ - try: - os.environ.pop("TESTING") - except KeyError: - pass + def setUp(cls): + os.environ["TESTING"] = "1" def test_attributes_in_leaves(self): """Check if the attributes in leaves have correct values so they form a