mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-18 00:46:02 +00:00
#3 Complete multiclass in Stree
Add multiclass dimensions management in the distances method. Add gamma hyperparameter for non-linear kernels.
This commit is contained in:
@@ -126,6 +126,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
random_state: int = None,
|
random_state: int = None,
|
||||||
max_depth: int = None,
|
max_depth: int = None,
|
||||||
tol: float = 1e-4,
|
tol: float = 1e-4,
|
||||||
|
gamma="scale",
|
||||||
min_samples_split: int = 0,
|
min_samples_split: int = 0,
|
||||||
):
|
):
|
||||||
self.max_iter = max_iter
|
self.max_iter = max_iter
|
||||||
@@ -134,6 +135,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
self.random_state = random_state
|
self.random_state = random_state
|
||||||
self.max_depth = max_depth
|
self.max_depth = max_depth
|
||||||
self.tol = tol
|
self.tol = tol
|
||||||
|
self.gamma = gamma
|
||||||
self.min_samples_split = min_samples_split
|
self.min_samples_split = min_samples_split
|
||||||
|
|
||||||
def _more_tags(self) -> dict:
|
def _more_tags(self) -> dict:
|
||||||
@@ -144,21 +146,6 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
"""
|
"""
|
||||||
return {"binary_only": True, "requires_y": True}
|
return {"binary_only": True, "requires_y": True}
|
||||||
|
|
||||||
def _linear_function(self, data: np.array, node: Snode) -> np.array:
|
|
||||||
"""Compute the distance of set of samples to a hyperplane, in
|
|
||||||
multiclass classification it should compute the distance to a
|
|
||||||
hyperplane of each class
|
|
||||||
|
|
||||||
:param data: dataset of samples
|
|
||||||
:type data: np.array shape(m, n)
|
|
||||||
:param node: the node that contains the hyperplance coefficients
|
|
||||||
:type node: Snode shape(1, n)
|
|
||||||
:return: array of distances of each sample to the hyperplane
|
|
||||||
:rtype: np.array
|
|
||||||
"""
|
|
||||||
coef = node._clf.coef_[0, :].reshape(-1, data.shape[1])
|
|
||||||
return data.dot(coef.T) + node._clf.intercept_[0]
|
|
||||||
|
|
||||||
def _split_array(self, origin: np.array, down: np.array) -> list:
|
def _split_array(self, origin: np.array, down: np.array) -> list:
|
||||||
"""Split an array in two based on indices passed as down and its complement
|
"""Split an array in two based on indices passed as down and its complement
|
||||||
|
|
||||||
@@ -170,7 +157,6 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
:rtype: list
|
:rtype: list
|
||||||
"""
|
"""
|
||||||
up = ~down
|
up = ~down
|
||||||
print(self.kernel, up.shape, down.shape)
|
|
||||||
return (
|
return (
|
||||||
origin[up[:, 0]] if any(up) else None,
|
origin[up[:, 0]] if any(up) else None,
|
||||||
origin[down[:, 0]] if any(down) else None,
|
origin[down[:, 0]] if any(down) else None,
|
||||||
@@ -187,7 +173,12 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
the hyperplane of the node
|
the hyperplane of the node
|
||||||
:rtype: np.array
|
:rtype: np.array
|
||||||
"""
|
"""
|
||||||
return np.expand_dims(node._clf.decision_function(data), 1)
|
res = node._clf.decision_function(data)
|
||||||
|
if res.ndim == 1:
|
||||||
|
return np.expand_dims(res, 1)
|
||||||
|
elif res.shape[1] > 1:
|
||||||
|
res = np.delete(res, slice(1, res.shape[1]), axis=1)
|
||||||
|
return res
|
||||||
|
|
||||||
def _split_criteria(self, data: np.array) -> np.array:
|
def _split_criteria(self, data: np.array) -> np.array:
|
||||||
"""Set the criteria to split arrays
|
"""Set the criteria to split arrays
|
||||||
@@ -256,9 +247,8 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
run_tree(self.tree_)
|
run_tree(self.tree_)
|
||||||
|
|
||||||
def _build_clf(self):
|
def _build_clf(self):
|
||||||
""" Select the correct classifier for the node
|
""" Build the correct classifier for the node
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return (
|
return (
|
||||||
LinearSVC(
|
LinearSVC(
|
||||||
max_iter=self.max_iter,
|
max_iter=self.max_iter,
|
||||||
@@ -272,6 +262,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
max_iter=self.max_iter,
|
max_iter=self.max_iter,
|
||||||
tol=self.tol,
|
tol=self.tol,
|
||||||
C=self.C,
|
C=self.C,
|
||||||
|
gamma=self.gamma,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@@ -41,6 +41,9 @@ class Snode_graph(Snode):
|
|||||||
def set_axis_limits(self, limits: tuple):
|
def set_axis_limits(self, limits: tuple):
|
||||||
self._xlimits, self._ylimits, self._zlimits = limits
|
self._xlimits, self._ylimits, self._zlimits = limits
|
||||||
|
|
||||||
|
def get_axis_limits(self) -> tuple:
|
||||||
|
return self._xlimits, self._ylimits, self._zlimits
|
||||||
|
|
||||||
def _set_graphics_axis(self, ax: Axes3D):
|
def _set_graphics_axis(self, ax: Axes3D):
|
||||||
ax.set_xlim(self._xlimits)
|
ax.set_xlim(self._xlimits)
|
||||||
ax.set_ylim(self._ylimits)
|
ax.set_ylim(self._ylimits)
|
||||||
@@ -50,7 +53,7 @@ class Snode_graph(Snode):
|
|||||||
self, save_folder: str = "./", save_prefix: str = "", save_seq: int = 1
|
self, save_folder: str = "./", save_prefix: str = "", save_seq: int = 1
|
||||||
):
|
):
|
||||||
_, fig = self.plot_hyperplane()
|
_, fig = self.plot_hyperplane()
|
||||||
name = f"{save_folder}{save_prefix}STnode{save_seq}.png"
|
name = os.path.join(save_folder, f"{save_prefix}STnode{save_seq}.png")
|
||||||
fig.savefig(name, bbox_inches="tight")
|
fig.savefig(name, bbox_inches="tight")
|
||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
|
|
||||||
|
@@ -8,7 +8,7 @@ import matplotlib.pyplot as plt
|
|||||||
import warnings
|
import warnings
|
||||||
from sklearn.datasets import make_classification
|
from sklearn.datasets import make_classification
|
||||||
|
|
||||||
from stree import Stree_grapher, Snode_graph
|
from stree import Stree_grapher, Snode_graph, Snode
|
||||||
|
|
||||||
|
|
||||||
def get_dataset(random_state=0, n_features=3):
|
def get_dataset(random_state=0, n_features=3):
|
||||||
@@ -30,18 +30,14 @@ def get_dataset(random_state=0, n_features=3):
|
|||||||
|
|
||||||
class Stree_grapher_test(unittest.TestCase):
|
class Stree_grapher_test(unittest.TestCase):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
os.environ["TESTING"] = "1"
|
|
||||||
self._random_state = 1
|
self._random_state = 1
|
||||||
self._clf = Stree_grapher(dict(random_state=self._random_state))
|
self._clf = Stree_grapher(dict(random_state=self._random_state))
|
||||||
self._clf.fit(*get_dataset(self._random_state, n_features=4))
|
self._clf.fit(*get_dataset(self._random_state, n_features=4))
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def setUp(cls):
|
||||||
try:
|
os.environ["TESTING"] = "1"
|
||||||
os.environ.pop("TESTING")
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_iterator(self):
|
def test_iterator(self):
|
||||||
"""Check preorder iterator
|
"""Check preorder iterator
|
||||||
@@ -73,7 +69,9 @@ class Stree_grapher_test(unittest.TestCase):
|
|||||||
self.assertGreater(accuracy_score, 0.86)
|
self.assertGreater(accuracy_score, 0.86)
|
||||||
|
|
||||||
def test_save_all(self):
|
def test_save_all(self):
|
||||||
folder_name = "/tmp/"
|
folder_name = os.path.join(os.sep, "tmp", "stree")
|
||||||
|
if os.path.isdir(folder_name):
|
||||||
|
os.rmdir(folder_name)
|
||||||
file_names = [
|
file_names = [
|
||||||
os.path.join(folder_name, f"STnode{i}.png") for i in range(1, 8)
|
os.path.join(folder_name, f"STnode{i}.png") for i in range(1, 8)
|
||||||
]
|
]
|
||||||
@@ -85,6 +83,7 @@ class Stree_grapher_test(unittest.TestCase):
|
|||||||
self.assertTrue(os.path.exists(file_name))
|
self.assertTrue(os.path.exists(file_name))
|
||||||
self.assertEqual("png", imghdr.what(file_name))
|
self.assertEqual("png", imghdr.what(file_name))
|
||||||
os.remove(file_name)
|
os.remove(file_name)
|
||||||
|
os.rmdir(folder_name)
|
||||||
|
|
||||||
def test_plot_all(self):
|
def test_plot_all(self):
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
@@ -98,20 +97,14 @@ class Stree_grapher_test(unittest.TestCase):
|
|||||||
|
|
||||||
class Snode_graph_test(unittest.TestCase):
|
class Snode_graph_test(unittest.TestCase):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
os.environ["TESTING"] = "1"
|
|
||||||
self._random_state = 1
|
self._random_state = 1
|
||||||
self._clf = Stree_grapher(dict(random_state=self._random_state))
|
self._clf = Stree_grapher(dict(random_state=self._random_state))
|
||||||
self._clf.fit(*get_dataset(self._random_state))
|
self._clf.fit(*get_dataset(self._random_state))
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def setUp(cls):
|
||||||
"""Remove the testing environ variable
|
os.environ["TESTING"] = "1"
|
||||||
"""
|
|
||||||
try:
|
|
||||||
os.environ.pop("TESTING")
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_plot_size(self):
|
def test_plot_size(self):
|
||||||
default = self._clf._tree_gr.get_plot_size()
|
default = self._clf._tree_gr.get_plot_size()
|
||||||
@@ -205,3 +198,14 @@ class Snode_graph_test(unittest.TestCase):
|
|||||||
self._clf._tree_gr.plot_distribution()
|
self._clf._tree_gr.plot_distribution()
|
||||||
num_figures_after = plt.gcf().number
|
num_figures_after = plt.gcf().number
|
||||||
self.assertEqual(1, num_figures_after - num_figures_before)
|
self.assertEqual(1, num_figures_after - num_figures_before)
|
||||||
|
|
||||||
|
def test_set_axis_limits(self):
|
||||||
|
node = Snode_graph(Snode(None, None, None, "test"))
|
||||||
|
limits = (-2, 2), (-3, 3), (-4, 4)
|
||||||
|
node.set_axis_limits(limits)
|
||||||
|
computed = node.get_axis_limits()
|
||||||
|
x, y, z = limits
|
||||||
|
xx, yy, zz = computed
|
||||||
|
self.assertEqual(x, xx)
|
||||||
|
self.assertEqual(y, yy)
|
||||||
|
self.assertEqual(z, zz)
|
||||||
|
@@ -26,17 +26,13 @@ def get_dataset(random_state=0):
|
|||||||
|
|
||||||
class Stree_test(unittest.TestCase):
|
class Stree_test(unittest.TestCase):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
os.environ["TESTING"] = "1"
|
|
||||||
self._random_state = 1
|
self._random_state = 1
|
||||||
self._kernels = ["linear", "rbf", "poly"]
|
self._kernels = ["linear", "rbf", "poly"]
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def setUp(cls):
|
||||||
try:
|
os.environ["TESTING"] = "1"
|
||||||
os.environ.pop("TESTING")
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _check_tree(self, node: Snode):
|
def _check_tree(self, node: Snode):
|
||||||
"""Check recursively that the nodes that are not leaves have the
|
"""Check recursively that the nodes that are not leaves have the
|
||||||
@@ -79,6 +75,9 @@ class Stree_test(unittest.TestCase):
|
|||||||
def test_build_tree(self):
|
def test_build_tree(self):
|
||||||
"""Check if the tree is built the same way as predictions of models
|
"""Check if the tree is built the same way as predictions of models
|
||||||
"""
|
"""
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
warnings.filterwarnings("ignore")
|
||||||
for kernel in self._kernels:
|
for kernel in self._kernels:
|
||||||
clf = Stree(kernel=kernel, random_state=self._random_state)
|
clf = Stree(kernel=kernel, random_state=self._random_state)
|
||||||
clf.fit(*get_dataset(self._random_state))
|
clf.fit(*get_dataset(self._random_state))
|
||||||
@@ -260,20 +259,14 @@ class Stree_test(unittest.TestCase):
|
|||||||
|
|
||||||
class Snode_test(unittest.TestCase):
|
class Snode_test(unittest.TestCase):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
os.environ["TESTING"] = "1"
|
|
||||||
self._random_state = 1
|
self._random_state = 1
|
||||||
self._clf = Stree(random_state=self._random_state)
|
self._clf = Stree(random_state=self._random_state)
|
||||||
self._clf.fit(*get_dataset(self._random_state))
|
self._clf.fit(*get_dataset(self._random_state))
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def setUp(cls):
|
||||||
"""[summary]
|
os.environ["TESTING"] = "1"
|
||||||
"""
|
|
||||||
try:
|
|
||||||
os.environ.pop("TESTING")
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_attributes_in_leaves(self):
|
def test_attributes_in_leaves(self):
|
||||||
"""Check if the attributes in leaves have correct values so they form a
|
"""Check if the attributes in leaves have correct values so they form a
|
||||||
|
Reference in New Issue
Block a user