#3 Complete multiclass in Stree

Handle multiclass output dimensions in the distances method
Add gamma hyperparameter for non-linear kernels
This commit is contained in:
2020-06-08 13:54:24 +02:00
parent 3a48d8b405
commit d7c0bc3bc5
4 changed files with 41 additions and 50 deletions

View File

@@ -126,6 +126,7 @@ class Stree(BaseEstimator, ClassifierMixin):
random_state: int = None, random_state: int = None,
max_depth: int = None, max_depth: int = None,
tol: float = 1e-4, tol: float = 1e-4,
gamma="scale",
min_samples_split: int = 0, min_samples_split: int = 0,
): ):
self.max_iter = max_iter self.max_iter = max_iter
@@ -134,6 +135,7 @@ class Stree(BaseEstimator, ClassifierMixin):
self.random_state = random_state self.random_state = random_state
self.max_depth = max_depth self.max_depth = max_depth
self.tol = tol self.tol = tol
self.gamma = gamma
self.min_samples_split = min_samples_split self.min_samples_split = min_samples_split
def _more_tags(self) -> dict: def _more_tags(self) -> dict:
@@ -144,21 +146,6 @@ class Stree(BaseEstimator, ClassifierMixin):
""" """
return {"binary_only": True, "requires_y": True} return {"binary_only": True, "requires_y": True}
def _linear_function(self, data: np.array, node: Snode) -> np.array:
    """Compute the signed value of the node's linear decision function
    for every sample in data.

    Only the first row of the classifier coefficients and the first
    intercept are used, so in a multiclass setting the result refers to
    the first hyperplane only, not to one hyperplane per class.

    :param data: dataset of samples
    :type data: np.array shape(m, n)
    :param node: the node that contains the hyperplane coefficients
    :type node: Snode
    :return: array with the value of the linear function for each sample
    :rtype: np.array shape(m, 1)
    """
    # Take the first coefficient row and force it to shape (1, n) so the
    # dot product below always yields a column vector.
    coef = node._clf.coef_[0, :].reshape(-1, data.shape[1])
    return data.dot(coef.T) + node._clf.intercept_[0]
def _split_array(self, origin: np.array, down: np.array) -> list: def _split_array(self, origin: np.array, down: np.array) -> list:
"""Split an array in two based on indices passed as down and its complement """Split an array in two based on indices passed as down and its complement
@@ -170,7 +157,6 @@ class Stree(BaseEstimator, ClassifierMixin):
:rtype: list :rtype: list
""" """
up = ~down up = ~down
print(self.kernel, up.shape, down.shape)
return ( return (
origin[up[:, 0]] if any(up) else None, origin[up[:, 0]] if any(up) else None,
origin[down[:, 0]] if any(down) else None, origin[down[:, 0]] if any(down) else None,
@@ -187,7 +173,12 @@ class Stree(BaseEstimator, ClassifierMixin):
the hyperplane of the node the hyperplane of the node
:rtype: np.array :rtype: np.array
""" """
return np.expand_dims(node._clf.decision_function(data), 1) res = node._clf.decision_function(data)
if res.ndim == 1:
return np.expand_dims(res, 1)
elif res.shape[1] > 1:
res = np.delete(res, slice(1, res.shape[1]), axis=1)
return res
def _split_criteria(self, data: np.array) -> np.array: def _split_criteria(self, data: np.array) -> np.array:
"""Set the criteria to split arrays """Set the criteria to split arrays
@@ -256,9 +247,8 @@ class Stree(BaseEstimator, ClassifierMixin):
run_tree(self.tree_) run_tree(self.tree_)
def _build_clf(self): def _build_clf(self):
""" Select the correct classifier for the node """ Build the correct classifier for the node
""" """
return ( return (
LinearSVC( LinearSVC(
max_iter=self.max_iter, max_iter=self.max_iter,
@@ -272,6 +262,7 @@ class Stree(BaseEstimator, ClassifierMixin):
max_iter=self.max_iter, max_iter=self.max_iter,
tol=self.tol, tol=self.tol,
C=self.C, C=self.C,
gamma=self.gamma,
) )
) )

View File

@@ -41,6 +41,9 @@ class Snode_graph(Snode):
def set_axis_limits(self, limits: tuple): def set_axis_limits(self, limits: tuple):
self._xlimits, self._ylimits, self._zlimits = limits self._xlimits, self._ylimits, self._zlimits = limits
def get_axis_limits(self) -> tuple:
    """Return the axis limits currently stored in the node.

    :return: the x, y and z limits, in that order
    :rtype: tuple
    """
    return self._xlimits, self._ylimits, self._zlimits
def _set_graphics_axis(self, ax: Axes3D): def _set_graphics_axis(self, ax: Axes3D):
ax.set_xlim(self._xlimits) ax.set_xlim(self._xlimits)
ax.set_ylim(self._ylimits) ax.set_ylim(self._ylimits)
@@ -50,7 +53,7 @@ class Snode_graph(Snode):
self, save_folder: str = "./", save_prefix: str = "", save_seq: int = 1 self, save_folder: str = "./", save_prefix: str = "", save_seq: int = 1
): ):
_, fig = self.plot_hyperplane() _, fig = self.plot_hyperplane()
name = f"{save_folder}{save_prefix}STnode{save_seq}.png" name = os.path.join(save_folder, f"{save_prefix}STnode{save_seq}.png")
fig.savefig(name, bbox_inches="tight") fig.savefig(name, bbox_inches="tight")
plt.close(fig) plt.close(fig)

View File

@@ -8,7 +8,7 @@ import matplotlib.pyplot as plt
import warnings import warnings
from sklearn.datasets import make_classification from sklearn.datasets import make_classification
from stree import Stree_grapher, Snode_graph from stree import Stree_grapher, Snode_graph, Snode
def get_dataset(random_state=0, n_features=3): def get_dataset(random_state=0, n_features=3):
@@ -30,18 +30,14 @@ def get_dataset(random_state=0, n_features=3):
class Stree_grapher_test(unittest.TestCase): class Stree_grapher_test(unittest.TestCase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
os.environ["TESTING"] = "1"
self._random_state = 1 self._random_state = 1
self._clf = Stree_grapher(dict(random_state=self._random_state)) self._clf = Stree_grapher(dict(random_state=self._random_state))
self._clf.fit(*get_dataset(self._random_state, n_features=4)) self._clf.fit(*get_dataset(self._random_state, n_features=4))
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@classmethod @classmethod
def tearDownClass(cls): def setUp(cls):
try: os.environ["TESTING"] = "1"
os.environ.pop("TESTING")
except KeyError:
pass
def test_iterator(self): def test_iterator(self):
"""Check preorder iterator """Check preorder iterator
@@ -73,7 +69,9 @@ class Stree_grapher_test(unittest.TestCase):
self.assertGreater(accuracy_score, 0.86) self.assertGreater(accuracy_score, 0.86)
def test_save_all(self): def test_save_all(self):
folder_name = "/tmp/" folder_name = os.path.join(os.sep, "tmp", "stree")
if os.path.isdir(folder_name):
os.rmdir(folder_name)
file_names = [ file_names = [
os.path.join(folder_name, f"STnode{i}.png") for i in range(1, 8) os.path.join(folder_name, f"STnode{i}.png") for i in range(1, 8)
] ]
@@ -85,6 +83,7 @@ class Stree_grapher_test(unittest.TestCase):
self.assertTrue(os.path.exists(file_name)) self.assertTrue(os.path.exists(file_name))
self.assertEqual("png", imghdr.what(file_name)) self.assertEqual("png", imghdr.what(file_name))
os.remove(file_name) os.remove(file_name)
os.rmdir(folder_name)
def test_plot_all(self): def test_plot_all(self):
with warnings.catch_warnings(): with warnings.catch_warnings():
@@ -98,20 +97,14 @@ class Stree_grapher_test(unittest.TestCase):
class Snode_graph_test(unittest.TestCase): class Snode_graph_test(unittest.TestCase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
os.environ["TESTING"] = "1"
self._random_state = 1 self._random_state = 1
self._clf = Stree_grapher(dict(random_state=self._random_state)) self._clf = Stree_grapher(dict(random_state=self._random_state))
self._clf.fit(*get_dataset(self._random_state)) self._clf.fit(*get_dataset(self._random_state))
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@classmethod @classmethod
def tearDownClass(cls): def setUp(cls):
"""Remove the testing environ variable os.environ["TESTING"] = "1"
"""
try:
os.environ.pop("TESTING")
except KeyError:
pass
def test_plot_size(self): def test_plot_size(self):
default = self._clf._tree_gr.get_plot_size() default = self._clf._tree_gr.get_plot_size()
@@ -205,3 +198,14 @@ class Snode_graph_test(unittest.TestCase):
self._clf._tree_gr.plot_distribution() self._clf._tree_gr.plot_distribution()
num_figures_after = plt.gcf().number num_figures_after = plt.gcf().number
self.assertEqual(1, num_figures_after - num_figures_before) self.assertEqual(1, num_figures_after - num_figures_before)
def test_set_axis_limits(self):
    """Check that get_axis_limits returns the limits stored by
    set_axis_limits, axis by axis and in the same order
    """
    node = Snode_graph(Snode(None, None, None, "test"))
    limits = (-2, 2), (-3, 3), (-4, 4)
    node.set_axis_limits(limits)
    computed = node.get_axis_limits()
    # A single comparison covers x, y and z and reports the full
    # expected/actual triples on failure.
    self.assertEqual(limits, tuple(computed))

View File

@@ -26,17 +26,13 @@ def get_dataset(random_state=0):
class Stree_test(unittest.TestCase): class Stree_test(unittest.TestCase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
os.environ["TESTING"] = "1"
self._random_state = 1 self._random_state = 1
self._kernels = ["linear", "rbf", "poly"] self._kernels = ["linear", "rbf", "poly"]
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@classmethod @classmethod
def tearDownClass(cls): def setUp(cls):
try: os.environ["TESTING"] = "1"
os.environ.pop("TESTING")
except KeyError:
pass
def _check_tree(self, node: Snode): def _check_tree(self, node: Snode):
"""Check recursively that the nodes that are not leaves have the """Check recursively that the nodes that are not leaves have the
@@ -79,6 +75,9 @@ class Stree_test(unittest.TestCase):
def test_build_tree(self): def test_build_tree(self):
"""Check if the tree is built the same way as predictions of models """Check if the tree is built the same way as predictions of models
""" """
import warnings
warnings.filterwarnings("ignore")
for kernel in self._kernels: for kernel in self._kernels:
clf = Stree(kernel=kernel, random_state=self._random_state) clf = Stree(kernel=kernel, random_state=self._random_state)
clf.fit(*get_dataset(self._random_state)) clf.fit(*get_dataset(self._random_state))
@@ -260,20 +259,14 @@ class Stree_test(unittest.TestCase):
class Snode_test(unittest.TestCase): class Snode_test(unittest.TestCase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
os.environ["TESTING"] = "1"
self._random_state = 1 self._random_state = 1
self._clf = Stree(random_state=self._random_state) self._clf = Stree(random_state=self._random_state)
self._clf.fit(*get_dataset(self._random_state)) self._clf.fit(*get_dataset(self._random_state))
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@classmethod @classmethod
def tearDownClass(cls): def setUp(cls):
"""[summary] os.environ["TESTING"] = "1"
"""
try:
os.environ.pop("TESTING")
except KeyError:
pass
def test_attributes_in_leaves(self): def test_attributes_in_leaves(self):
"""Check if the attributes in leaves have correct values so they form a """Check if the attributes in leaves have correct values so they form a