mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-15 15:36:00 +00:00
#3 Complete multiclass in Stree
Add multiclass dimensions management in distances method Add gamma hyperparameter for non linear kernels
This commit is contained in:
@@ -126,6 +126,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
random_state: int = None,
|
||||
max_depth: int = None,
|
||||
tol: float = 1e-4,
|
||||
gamma="scale",
|
||||
min_samples_split: int = 0,
|
||||
):
|
||||
self.max_iter = max_iter
|
||||
@@ -134,6 +135,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
self.random_state = random_state
|
||||
self.max_depth = max_depth
|
||||
self.tol = tol
|
||||
self.gamma = gamma
|
||||
self.min_samples_split = min_samples_split
|
||||
|
||||
def _more_tags(self) -> dict:
|
||||
@@ -144,21 +146,6 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
"""
|
||||
return {"binary_only": True, "requires_y": True}
|
||||
|
||||
def _linear_function(self, data: np.array, node: Snode) -> np.array:
|
||||
"""Compute the distance of set of samples to a hyperplane, in
|
||||
multiclass classification it should compute the distance to a
|
||||
hyperplane of each class
|
||||
|
||||
:param data: dataset of samples
|
||||
:type data: np.array shape(m, n)
|
||||
:param node: the node that contains the hyperplance coefficients
|
||||
:type node: Snode shape(1, n)
|
||||
:return: array of distances of each sample to the hyperplane
|
||||
:rtype: np.array
|
||||
"""
|
||||
coef = node._clf.coef_[0, :].reshape(-1, data.shape[1])
|
||||
return data.dot(coef.T) + node._clf.intercept_[0]
|
||||
|
||||
def _split_array(self, origin: np.array, down: np.array) -> list:
|
||||
"""Split an array in two based on indices passed as down and its complement
|
||||
|
||||
@@ -170,7 +157,6 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
:rtype: list
|
||||
"""
|
||||
up = ~down
|
||||
print(self.kernel, up.shape, down.shape)
|
||||
return (
|
||||
origin[up[:, 0]] if any(up) else None,
|
||||
origin[down[:, 0]] if any(down) else None,
|
||||
@@ -187,7 +173,12 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
the hyperplane of the node
|
||||
:rtype: np.array
|
||||
"""
|
||||
return np.expand_dims(node._clf.decision_function(data), 1)
|
||||
res = node._clf.decision_function(data)
|
||||
if res.ndim == 1:
|
||||
return np.expand_dims(res, 1)
|
||||
elif res.shape[1] > 1:
|
||||
res = np.delete(res, slice(1, res.shape[1]), axis=1)
|
||||
return res
|
||||
|
||||
def _split_criteria(self, data: np.array) -> np.array:
|
||||
"""Set the criteria to split arrays
|
||||
@@ -256,9 +247,8 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
run_tree(self.tree_)
|
||||
|
||||
def _build_clf(self):
|
||||
""" Select the correct classifier for the node
|
||||
""" Build the correct classifier for the node
|
||||
"""
|
||||
|
||||
return (
|
||||
LinearSVC(
|
||||
max_iter=self.max_iter,
|
||||
@@ -272,6 +262,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
max_iter=self.max_iter,
|
||||
tol=self.tol,
|
||||
C=self.C,
|
||||
gamma=self.gamma,
|
||||
)
|
||||
)
|
||||
|
||||
|
@@ -41,6 +41,9 @@ class Snode_graph(Snode):
|
||||
def set_axis_limits(self, limits: tuple):
|
||||
self._xlimits, self._ylimits, self._zlimits = limits
|
||||
|
||||
def get_axis_limits(self) -> tuple:
|
||||
return self._xlimits, self._ylimits, self._zlimits
|
||||
|
||||
def _set_graphics_axis(self, ax: Axes3D):
|
||||
ax.set_xlim(self._xlimits)
|
||||
ax.set_ylim(self._ylimits)
|
||||
@@ -50,7 +53,7 @@ class Snode_graph(Snode):
|
||||
self, save_folder: str = "./", save_prefix: str = "", save_seq: int = 1
|
||||
):
|
||||
_, fig = self.plot_hyperplane()
|
||||
name = f"{save_folder}{save_prefix}STnode{save_seq}.png"
|
||||
name = os.path.join(save_folder, f"{save_prefix}STnode{save_seq}.png")
|
||||
fig.savefig(name, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
|
||||
|
@@ -8,7 +8,7 @@ import matplotlib.pyplot as plt
|
||||
import warnings
|
||||
from sklearn.datasets import make_classification
|
||||
|
||||
from stree import Stree_grapher, Snode_graph
|
||||
from stree import Stree_grapher, Snode_graph, Snode
|
||||
|
||||
|
||||
def get_dataset(random_state=0, n_features=3):
|
||||
@@ -30,18 +30,14 @@ def get_dataset(random_state=0, n_features=3):
|
||||
|
||||
class Stree_grapher_test(unittest.TestCase):
|
||||
def __init__(self, *args, **kwargs):
|
||||
os.environ["TESTING"] = "1"
|
||||
self._random_state = 1
|
||||
self._clf = Stree_grapher(dict(random_state=self._random_state))
|
||||
self._clf.fit(*get_dataset(self._random_state, n_features=4))
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
try:
|
||||
os.environ.pop("TESTING")
|
||||
except KeyError:
|
||||
pass
|
||||
def setUp(cls):
|
||||
os.environ["TESTING"] = "1"
|
||||
|
||||
def test_iterator(self):
|
||||
"""Check preorder iterator
|
||||
@@ -73,7 +69,9 @@ class Stree_grapher_test(unittest.TestCase):
|
||||
self.assertGreater(accuracy_score, 0.86)
|
||||
|
||||
def test_save_all(self):
|
||||
folder_name = "/tmp/"
|
||||
folder_name = os.path.join(os.sep, "tmp", "stree")
|
||||
if os.path.isdir(folder_name):
|
||||
os.rmdir(folder_name)
|
||||
file_names = [
|
||||
os.path.join(folder_name, f"STnode{i}.png") for i in range(1, 8)
|
||||
]
|
||||
@@ -85,6 +83,7 @@ class Stree_grapher_test(unittest.TestCase):
|
||||
self.assertTrue(os.path.exists(file_name))
|
||||
self.assertEqual("png", imghdr.what(file_name))
|
||||
os.remove(file_name)
|
||||
os.rmdir(folder_name)
|
||||
|
||||
def test_plot_all(self):
|
||||
with warnings.catch_warnings():
|
||||
@@ -98,20 +97,14 @@ class Stree_grapher_test(unittest.TestCase):
|
||||
|
||||
class Snode_graph_test(unittest.TestCase):
|
||||
def __init__(self, *args, **kwargs):
|
||||
os.environ["TESTING"] = "1"
|
||||
self._random_state = 1
|
||||
self._clf = Stree_grapher(dict(random_state=self._random_state))
|
||||
self._clf.fit(*get_dataset(self._random_state))
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
"""Remove the testing environ variable
|
||||
"""
|
||||
try:
|
||||
os.environ.pop("TESTING")
|
||||
except KeyError:
|
||||
pass
|
||||
def setUp(cls):
|
||||
os.environ["TESTING"] = "1"
|
||||
|
||||
def test_plot_size(self):
|
||||
default = self._clf._tree_gr.get_plot_size()
|
||||
@@ -205,3 +198,14 @@ class Snode_graph_test(unittest.TestCase):
|
||||
self._clf._tree_gr.plot_distribution()
|
||||
num_figures_after = plt.gcf().number
|
||||
self.assertEqual(1, num_figures_after - num_figures_before)
|
||||
|
||||
def test_set_axis_limits(self):
|
||||
node = Snode_graph(Snode(None, None, None, "test"))
|
||||
limits = (-2, 2), (-3, 3), (-4, 4)
|
||||
node.set_axis_limits(limits)
|
||||
computed = node.get_axis_limits()
|
||||
x, y, z = limits
|
||||
xx, yy, zz = computed
|
||||
self.assertEqual(x, xx)
|
||||
self.assertEqual(y, yy)
|
||||
self.assertEqual(z, zz)
|
||||
|
@@ -26,17 +26,13 @@ def get_dataset(random_state=0):
|
||||
|
||||
class Stree_test(unittest.TestCase):
|
||||
def __init__(self, *args, **kwargs):
|
||||
os.environ["TESTING"] = "1"
|
||||
self._random_state = 1
|
||||
self._kernels = ["linear", "rbf", "poly"]
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
try:
|
||||
os.environ.pop("TESTING")
|
||||
except KeyError:
|
||||
pass
|
||||
def setUp(cls):
|
||||
os.environ["TESTING"] = "1"
|
||||
|
||||
def _check_tree(self, node: Snode):
|
||||
"""Check recursively that the nodes that are not leaves have the
|
||||
@@ -79,6 +75,9 @@ class Stree_test(unittest.TestCase):
|
||||
def test_build_tree(self):
|
||||
"""Check if the tree is built the same way as predictions of models
|
||||
"""
|
||||
import warnings
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
for kernel in self._kernels:
|
||||
clf = Stree(kernel=kernel, random_state=self._random_state)
|
||||
clf.fit(*get_dataset(self._random_state))
|
||||
@@ -260,20 +259,14 @@ class Stree_test(unittest.TestCase):
|
||||
|
||||
class Snode_test(unittest.TestCase):
|
||||
def __init__(self, *args, **kwargs):
|
||||
os.environ["TESTING"] = "1"
|
||||
self._random_state = 1
|
||||
self._clf = Stree(random_state=self._random_state)
|
||||
self._clf.fit(*get_dataset(self._random_state))
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
"""[summary]
|
||||
"""
|
||||
try:
|
||||
os.environ.pop("TESTING")
|
||||
except KeyError:
|
||||
pass
|
||||
def setUp(cls):
|
||||
os.environ["TESTING"] = "1"
|
||||
|
||||
def test_attributes_in_leaves(self):
|
||||
"""Check if the attributes in leaves have correct values so they form a
|
||||
|
Reference in New Issue
Block a user