#3 Complete multiclass in Stree

Add multiclass dimensions management in distances method
Add gamma hyperparameter for non linear kernels
This commit is contained in:
2020-06-08 13:54:24 +02:00
parent 3a48d8b405
commit d7c0bc3bc5
4 changed files with 41 additions and 50 deletions

View File

@@ -126,6 +126,7 @@ class Stree(BaseEstimator, ClassifierMixin):
random_state: int = None,
max_depth: int = None,
tol: float = 1e-4,
gamma="scale",
min_samples_split: int = 0,
):
self.max_iter = max_iter
@@ -134,6 +135,7 @@ class Stree(BaseEstimator, ClassifierMixin):
self.random_state = random_state
self.max_depth = max_depth
self.tol = tol
self.gamma = gamma
self.min_samples_split = min_samples_split
def _more_tags(self) -> dict:
@@ -144,21 +146,6 @@ class Stree(BaseEstimator, ClassifierMixin):
"""
return {"binary_only": True, "requires_y": True}
def _linear_function(self, data: np.array, node: Snode) -> np.array:
"""Compute the distance of set of samples to a hyperplane, in
multiclass classification it should compute the distance to a
hyperplane of each class
:param data: dataset of samples
:type data: np.array shape(m, n)
:param node: the node that contains the hyperplance coefficients
:type node: Snode shape(1, n)
:return: array of distances of each sample to the hyperplane
:rtype: np.array
"""
coef = node._clf.coef_[0, :].reshape(-1, data.shape[1])
return data.dot(coef.T) + node._clf.intercept_[0]
def _split_array(self, origin: np.array, down: np.array) -> list:
"""Split an array in two based on indices passed as down and its complement
@@ -170,7 +157,6 @@ class Stree(BaseEstimator, ClassifierMixin):
:rtype: list
"""
up = ~down
print(self.kernel, up.shape, down.shape)
return (
origin[up[:, 0]] if any(up) else None,
origin[down[:, 0]] if any(down) else None,
@@ -187,7 +173,12 @@ class Stree(BaseEstimator, ClassifierMixin):
the hyperplane of the node
:rtype: np.array
"""
return np.expand_dims(node._clf.decision_function(data), 1)
res = node._clf.decision_function(data)
if res.ndim == 1:
return np.expand_dims(res, 1)
elif res.shape[1] > 1:
res = np.delete(res, slice(1, res.shape[1]), axis=1)
return res
def _split_criteria(self, data: np.array) -> np.array:
"""Set the criteria to split arrays
@@ -256,9 +247,8 @@ class Stree(BaseEstimator, ClassifierMixin):
run_tree(self.tree_)
def _build_clf(self):
""" Select the correct classifier for the node
""" Build the correct classifier for the node
"""
return (
LinearSVC(
max_iter=self.max_iter,
@@ -272,6 +262,7 @@ class Stree(BaseEstimator, ClassifierMixin):
max_iter=self.max_iter,
tol=self.tol,
C=self.C,
gamma=self.gamma,
)
)

View File

@@ -41,6 +41,9 @@ class Snode_graph(Snode):
def set_axis_limits(self, limits: tuple):
self._xlimits, self._ylimits, self._zlimits = limits
def get_axis_limits(self) -> tuple:
return self._xlimits, self._ylimits, self._zlimits
def _set_graphics_axis(self, ax: Axes3D):
ax.set_xlim(self._xlimits)
ax.set_ylim(self._ylimits)
@@ -50,7 +53,7 @@ class Snode_graph(Snode):
self, save_folder: str = "./", save_prefix: str = "", save_seq: int = 1
):
_, fig = self.plot_hyperplane()
name = f"{save_folder}{save_prefix}STnode{save_seq}.png"
name = os.path.join(save_folder, f"{save_prefix}STnode{save_seq}.png")
fig.savefig(name, bbox_inches="tight")
plt.close(fig)

View File

@@ -8,7 +8,7 @@ import matplotlib.pyplot as plt
import warnings
from sklearn.datasets import make_classification
from stree import Stree_grapher, Snode_graph
from stree import Stree_grapher, Snode_graph, Snode
def get_dataset(random_state=0, n_features=3):
@@ -30,18 +30,14 @@ def get_dataset(random_state=0, n_features=3):
class Stree_grapher_test(unittest.TestCase):
def __init__(self, *args, **kwargs):
os.environ["TESTING"] = "1"
self._random_state = 1
self._clf = Stree_grapher(dict(random_state=self._random_state))
self._clf.fit(*get_dataset(self._random_state, n_features=4))
super().__init__(*args, **kwargs)
@classmethod
def tearDownClass(cls):
try:
os.environ.pop("TESTING")
except KeyError:
pass
def setUp(cls):
os.environ["TESTING"] = "1"
def test_iterator(self):
"""Check preorder iterator
@@ -73,7 +69,9 @@ class Stree_grapher_test(unittest.TestCase):
self.assertGreater(accuracy_score, 0.86)
def test_save_all(self):
folder_name = "/tmp/"
folder_name = os.path.join(os.sep, "tmp", "stree")
if os.path.isdir(folder_name):
os.rmdir(folder_name)
file_names = [
os.path.join(folder_name, f"STnode{i}.png") for i in range(1, 8)
]
@@ -85,6 +83,7 @@ class Stree_grapher_test(unittest.TestCase):
self.assertTrue(os.path.exists(file_name))
self.assertEqual("png", imghdr.what(file_name))
os.remove(file_name)
os.rmdir(folder_name)
def test_plot_all(self):
with warnings.catch_warnings():
@@ -98,20 +97,14 @@ class Stree_grapher_test(unittest.TestCase):
class Snode_graph_test(unittest.TestCase):
def __init__(self, *args, **kwargs):
os.environ["TESTING"] = "1"
self._random_state = 1
self._clf = Stree_grapher(dict(random_state=self._random_state))
self._clf.fit(*get_dataset(self._random_state))
super().__init__(*args, **kwargs)
@classmethod
def tearDownClass(cls):
"""Remove the testing environ variable
"""
try:
os.environ.pop("TESTING")
except KeyError:
pass
def setUp(cls):
os.environ["TESTING"] = "1"
def test_plot_size(self):
default = self._clf._tree_gr.get_plot_size()
@@ -205,3 +198,14 @@ class Snode_graph_test(unittest.TestCase):
self._clf._tree_gr.plot_distribution()
num_figures_after = plt.gcf().number
self.assertEqual(1, num_figures_after - num_figures_before)
def test_set_axis_limits(self):
node = Snode_graph(Snode(None, None, None, "test"))
limits = (-2, 2), (-3, 3), (-4, 4)
node.set_axis_limits(limits)
computed = node.get_axis_limits()
x, y, z = limits
xx, yy, zz = computed
self.assertEqual(x, xx)
self.assertEqual(y, yy)
self.assertEqual(z, zz)

View File

@@ -26,17 +26,13 @@ def get_dataset(random_state=0):
class Stree_test(unittest.TestCase):
def __init__(self, *args, **kwargs):
os.environ["TESTING"] = "1"
self._random_state = 1
self._kernels = ["linear", "rbf", "poly"]
super().__init__(*args, **kwargs)
@classmethod
def tearDownClass(cls):
try:
os.environ.pop("TESTING")
except KeyError:
pass
def setUp(cls):
os.environ["TESTING"] = "1"
def _check_tree(self, node: Snode):
"""Check recursively that the nodes that are not leaves have the
@@ -79,6 +75,9 @@ class Stree_test(unittest.TestCase):
def test_build_tree(self):
"""Check if the tree is built the same way as predictions of models
"""
import warnings
warnings.filterwarnings("ignore")
for kernel in self._kernels:
clf = Stree(kernel=kernel, random_state=self._random_state)
clf.fit(*get_dataset(self._random_state))
@@ -260,20 +259,14 @@ class Stree_test(unittest.TestCase):
class Snode_test(unittest.TestCase):
def __init__(self, *args, **kwargs):
os.environ["TESTING"] = "1"
self._random_state = 1
self._clf = Stree(random_state=self._random_state)
self._clf.fit(*get_dataset(self._random_state))
super().__init__(*args, **kwargs)
@classmethod
def tearDownClass(cls):
"""[summary]
"""
try:
os.environ.pop("TESTING")
except KeyError:
pass
def setUp(cls):
os.environ["TESTING"] = "1"
def test_attributes_in_leaves(self):
"""Check if the attributes in leaves have correct values so they form a