mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-17 16:36:01 +00:00
#15 First approach
Create impurity function in Stree (consistent name, same criteria as other splitter parameter) Create test for the new function Update init test Update test splitter parameters Rename old impurity function to partition_impurity
This commit is contained in:
2
setup.py
2
setup.py
@@ -25,7 +25,7 @@ setuptools.setup(
|
||||
classifiers=[
|
||||
"Development Status :: 4 - Beta",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Natural Language :: English",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
"Intended Audience :: Science/Research",
|
||||
|
@@ -120,8 +120,7 @@ class Snode:
|
||||
|
||||
|
||||
class Siterator:
|
||||
"""Stree preorder iterator
|
||||
"""
|
||||
"""Stree preorder iterator"""
|
||||
|
||||
def __init__(self, tree: Snode):
|
||||
self._stack = []
|
||||
@@ -167,20 +166,22 @@ class Splitter:
|
||||
f"criterion must be gini or entropy got({criterion})"
|
||||
)
|
||||
|
||||
if criteria not in ["min_distance", "max_samples", "max_distance"]:
|
||||
if criteria not in [
|
||||
"max_samples",
|
||||
"impurity",
|
||||
]:
|
||||
raise ValueError(
|
||||
"split_criteria has to be min_distance "
|
||||
f"max_distance or max_samples got ({criteria})"
|
||||
f"criteria has to be max_samples or impurity; got ({criteria})"
|
||||
)
|
||||
|
||||
if splitter_type not in ["random", "best"]:
|
||||
raise ValueError(
|
||||
f"splitter must be either random or best got({splitter_type})"
|
||||
f"splitter must be either random or best, got({splitter_type})"
|
||||
)
|
||||
self.criterion_function = getattr(self, f"_{self._criterion}")
|
||||
self.decision_criteria = getattr(self, f"_{self._criteria}")
|
||||
|
||||
def impurity(self, y: np.array) -> np.array:
|
||||
def partition_impurity(self, y: np.array) -> np.array:
|
||||
return self.criterion_function(y)
|
||||
|
||||
@staticmethod
|
||||
@@ -266,34 +267,13 @@ class Splitter:
|
||||
def get_subspace(
|
||||
self, dataset: np.array, labels: np.array, max_features: int
|
||||
) -> list:
|
||||
"""Return the best subspace to make a split
|
||||
"""
|
||||
"""Return the best/random subspace to make a split"""
|
||||
indices = self._get_subspaces_set(dataset, labels, max_features)
|
||||
return dataset[:, indices], indices
|
||||
|
||||
@staticmethod
|
||||
def _min_distance(data: np.array, _) -> np.array:
|
||||
"""Assign class to min distances
|
||||
def _impurity(self, data: np.array, _) -> np.array:
|
||||
"""return distances of the class whose partition has less impurity
|
||||
|
||||
return a vector of classes so partition can separate class 0 from
|
||||
the rest of classes, ie. class 0 goes to one splitted node and the
|
||||
rest of classes go to the other
|
||||
:param data: distances to hyper plane of every class
|
||||
:type data: np.array (m, n_classes)
|
||||
:param _: enable call compat with other measures
|
||||
:type _: None
|
||||
:return: vector with the class assigned to each sample
|
||||
:rtype: np.array shape (m,)
|
||||
"""
|
||||
return np.argmin(data, axis=1)
|
||||
|
||||
@staticmethod
|
||||
def _max_distance(data: np.array, _) -> np.array:
|
||||
"""Assign class to max distances
|
||||
|
||||
return a vector of classes so partition can separate class 0 from
|
||||
the rest of classes, ie. class 0 goes to one splitted node and the
|
||||
rest of classes go to the other
|
||||
:param data: distances to hyper plane of every class
|
||||
:type data: np.array (m, n_classes)
|
||||
:param _: enable call compat with other measures
|
||||
@@ -302,7 +282,18 @@ class Splitter:
|
||||
(can be 0, 1, ...)
|
||||
:rtype: np.array shape (m,)
|
||||
"""
|
||||
return np.argmax(data, axis=1)
|
||||
min_impurity = float("inf")
|
||||
selected = 0
|
||||
y = data.copy()
|
||||
y[data <= 0] = 0
|
||||
y[data > 0] = 1
|
||||
y = y.astype(int)
|
||||
for col in range(data.shape[1]):
|
||||
impurity_of_class = self.partition_impurity(y[col])
|
||||
if impurity_of_class < min_impurity:
|
||||
selected = col
|
||||
min_impurity = impurity_of_class
|
||||
return data[:, selected]
|
||||
|
||||
@staticmethod
|
||||
def _max_samples(data: np.array, y: np.array) -> np.array:
|
||||
@@ -325,12 +316,15 @@ class Splitter:
|
||||
that should go to one side of the tree (down)
|
||||
|
||||
"""
|
||||
# data contains the distances of every sample to every class hyperplane
|
||||
# array of (m, nc) nc = # classes
|
||||
data = self._distances(node, samples)
|
||||
if data.shape[0] < self._min_samples_split:
|
||||
self._down = np.ones((data.shape[0]), dtype=bool)
|
||||
return
|
||||
if data.ndim > 1:
|
||||
# split criteria for multiclass
|
||||
# Convert data to a (m, 1) array selecting values for samples
|
||||
data = self.decision_criteria(data, node._y)
|
||||
self._down = data > 0
|
||||
|
||||
@@ -342,8 +336,8 @@ class Splitter:
|
||||
:type node: Snode
|
||||
:param data: samples to find out distance to hyperplane
|
||||
:type data: np.ndarray
|
||||
:return: array of shape (m, 1) with the distances of every sample to
|
||||
the hyperplane of the node
|
||||
:return: array of shape (m, nc) with the distances of every sample to
|
||||
the hyperplane of every class. nc = # of classes
|
||||
:rtype: np.array
|
||||
"""
|
||||
return node._clf.decision_function(data[:, node._features])
|
||||
@@ -521,7 +515,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
if np.unique(y_next).shape[0] != self.n_classes_:
|
||||
sample_weight += 1e-5
|
||||
clf.fit(Xs, y, sample_weight=sample_weight)
|
||||
impurity = self.splitter_.impurity(y)
|
||||
impurity = self.splitter_.partition_impurity(y)
|
||||
node = Snode(clf, X, y, features, impurity, title, sample_weight)
|
||||
self.depth_ = max(depth, self.depth_)
|
||||
self.splitter_.partition(X, node)
|
||||
@@ -544,8 +538,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
return node
|
||||
|
||||
def _build_predictor(self):
|
||||
"""Process the leaves to make them predictors
|
||||
"""
|
||||
"""Process the leaves to make them predictors"""
|
||||
|
||||
def run_tree(node: Snode):
|
||||
if node.is_leaf():
|
||||
@@ -557,8 +550,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
run_tree(self.tree_)
|
||||
|
||||
def _build_clf(self):
|
||||
""" Build the correct classifier for the node
|
||||
"""
|
||||
"""Build the correct classifier for the node"""
|
||||
return (
|
||||
LinearSVC(
|
||||
max_iter=self.max_iter,
|
||||
|
@@ -19,7 +19,7 @@ class Splitter_test(unittest.TestCase):
|
||||
min_samples_split=0,
|
||||
splitter_type="random",
|
||||
criterion="gini",
|
||||
criteria="min_distance",
|
||||
criteria="max_samples",
|
||||
random_state=None,
|
||||
):
|
||||
return Splitter(
|
||||
@@ -46,11 +46,7 @@ class Splitter_test(unittest.TestCase):
|
||||
_ = Splitter(clf=None)
|
||||
for splitter_type in ["best", "random"]:
|
||||
for criterion in ["gini", "entropy"]:
|
||||
for criteria in [
|
||||
"min_distance",
|
||||
"max_samples",
|
||||
"max_distance",
|
||||
]:
|
||||
for criteria in ["max_samples", "impurity"]:
|
||||
tcl = self.build(
|
||||
splitter_type=splitter_type,
|
||||
criterion=criterion,
|
||||
@@ -146,8 +142,8 @@ class Splitter_test(unittest.TestCase):
|
||||
self.assertEqual((4,), computed.shape)
|
||||
self.assertListEqual(expected.tolist(), computed.tolist())
|
||||
|
||||
def test_min_distance(self):
|
||||
tcl = self.build()
|
||||
def test_impurity(self):
|
||||
tcl = self.build(criteria="impurity")
|
||||
data = np.array(
|
||||
[
|
||||
[-0.1, 0.2, -0.3],
|
||||
@@ -156,23 +152,8 @@ class Splitter_test(unittest.TestCase):
|
||||
[0.1, 0.2, 0.3],
|
||||
]
|
||||
)
|
||||
expected = np.array([2, 2, 1, 0])
|
||||
computed = tcl._min_distance(data, None)
|
||||
self.assertEqual((4,), computed.shape)
|
||||
self.assertListEqual(expected.tolist(), computed.tolist())
|
||||
|
||||
def test_max_distance(self):
|
||||
tcl = self.build(criteria="max_distance")
|
||||
data = np.array(
|
||||
[
|
||||
[-0.1, 0.2, -0.3],
|
||||
[0.7, 0.01, -0.1],
|
||||
[0.7, -0.9, 0.5],
|
||||
[0.1, 0.2, 0.3],
|
||||
]
|
||||
)
|
||||
expected = np.array([1, 0, 0, 2])
|
||||
computed = tcl._max_distance(data, None)
|
||||
expected = np.array([-0.1, 0.7, 0.7, 0.1])
|
||||
computed = tcl._impurity(data, None)
|
||||
self.assertEqual((4,), computed.shape)
|
||||
self.assertListEqual(expected.tolist(), computed.tolist())
|
||||
|
||||
@@ -186,27 +167,22 @@ class Splitter_test(unittest.TestCase):
|
||||
|
||||
def test_splitter_parameter(self):
|
||||
expected_values = [
|
||||
[2, 3, 5, 7], # best entropy min_distance
|
||||
[0, 2, 4, 5], # best entropy max_samples
|
||||
[0, 2, 8, 12], # best entropy max_distance
|
||||
[1, 2, 5, 12], # best gini min_distance
|
||||
[0, 3, 4, 10], # best gini max_samples
|
||||
[1, 2, 9, 12], # best gini max_distance
|
||||
[3, 9, 11, 12], # random entropy min_distance
|
||||
[1, 5, 6, 9], # random entropy max_samples
|
||||
[1, 2, 4, 8], # random entropy max_distance
|
||||
[2, 6, 7, 12], # random gini min_distance
|
||||
[3, 9, 10, 11], # random gini max_samples
|
||||
[2, 5, 8, 12], # random gini max_distance
|
||||
[0, 1, 7, 9], # best entropy max_samples
|
||||
[3, 8, 10, 11], # best entropy impurity
|
||||
[0, 2, 8, 12], # best gini max_samples
|
||||
[1, 2, 5, 12], # best gini impurity
|
||||
[1, 2, 5, 10], # random entropy max_samples
|
||||
[4, 8, 9, 12], # random entropy impurity
|
||||
[3, 9, 11, 12], # random gini max_samples
|
||||
[1, 5, 6, 9], # random gini impurity
|
||||
]
|
||||
X, y = load_wine(return_X_y=True)
|
||||
rn = 0
|
||||
for splitter_type in ["best", "random"]:
|
||||
for criterion in ["entropy", "gini"]:
|
||||
for criteria in [
|
||||
"min_distance",
|
||||
"max_samples",
|
||||
"max_distance",
|
||||
"impurity",
|
||||
]:
|
||||
tcl = self.build(
|
||||
splitter_type=splitter_type,
|
||||
@@ -219,8 +195,10 @@ class Splitter_test(unittest.TestCase):
|
||||
dataset, computed = tcl.get_subspace(X, y, max_features=4)
|
||||
# print(
|
||||
# "{}, # {:7s}{:8s}{:15s}".format(
|
||||
# list(computed), splitter_type, criterion,
|
||||
# criteria,
|
||||
# list(computed),
|
||||
# splitter_type,
|
||||
# criterion,
|
||||
# criteria,
|
||||
# )
|
||||
# )
|
||||
self.assertListEqual(expected, list(computed))
|
||||
|
@@ -56,8 +56,7 @@ class Stree_test(unittest.TestCase):
|
||||
self._check_tree(node.get_up())
|
||||
|
||||
def test_build_tree(self):
|
||||
"""Check if the tree is built the same way as predictions of models
|
||||
"""
|
||||
"""Check if the tree is built the same way as predictions of models"""
|
||||
warnings.filterwarnings("ignore")
|
||||
for kernel in self._kernels:
|
||||
clf = Stree(kernel=kernel, random_state=self._random_state)
|
||||
@@ -99,8 +98,7 @@ class Stree_test(unittest.TestCase):
|
||||
self.assertListEqual(yp_line.tolist(), yp_once.tolist())
|
||||
|
||||
def test_iterator_and_str(self):
|
||||
"""Check preorder iterator
|
||||
"""
|
||||
"""Check preorder iterator"""
|
||||
expected = [
|
||||
"root feaures=(0, 1, 2) impurity=0.5000",
|
||||
"root - Down feaures=(0, 1, 2) impurity=0.0671",
|
||||
@@ -195,28 +193,22 @@ class Stree_test(unittest.TestCase):
|
||||
"max_samples linear": 0.9533333333333334,
|
||||
"max_samples rbf": 0.836,
|
||||
"max_samples poly": 0.9473333333333334,
|
||||
"min_distance linear": 0.9533333333333334,
|
||||
"min_distance rbf": 0.836,
|
||||
"min_distance poly": 0.9473333333333334,
|
||||
"max_distance linear": 0.9533333333333334,
|
||||
"max_distance rbf": 0.836,
|
||||
"max_distance poly": 0.9473333333333334,
|
||||
"impurity linear": 0.9533333333333334,
|
||||
"impurity rbf": 0.836,
|
||||
"impurity poly": 0.9473333333333334,
|
||||
},
|
||||
"Iris": {
|
||||
"max_samples linear": 0.98,
|
||||
"max_samples rbf": 1.0,
|
||||
"max_samples poly": 1.0,
|
||||
"min_distance linear": 0.98,
|
||||
"min_distance rbf": 1.0,
|
||||
"min_distance poly": 1.0,
|
||||
"max_distance linear": 0.98,
|
||||
"max_distance rbf": 1.0,
|
||||
"max_distance poly": 1.0,
|
||||
"impurity linear": 0.98,
|
||||
"impurity rbf": 1,
|
||||
"impurity poly": 1,
|
||||
},
|
||||
}
|
||||
for name, dataset in datasets.items():
|
||||
px, py = dataset
|
||||
for criteria in ["max_samples", "min_distance", "max_distance"]:
|
||||
for criteria in ["max_samples", "impurity"]:
|
||||
for kernel in self._kernels:
|
||||
clf = Stree(
|
||||
C=1e4,
|
||||
@@ -225,6 +217,7 @@ class Stree_test(unittest.TestCase):
|
||||
random_state=self._random_state,
|
||||
)
|
||||
clf.fit(px, py)
|
||||
print(f"{name} {criteria} {kernel}")
|
||||
outcome = outcomes[name][f"{criteria} {kernel}"]
|
||||
self.assertAlmostEqual(outcome, clf.score(px, py))
|
||||
|
||||
@@ -297,7 +290,10 @@ class Stree_test(unittest.TestCase):
|
||||
0.9433333333333334,
|
||||
]
|
||||
for kernel, accuracy_expected in zip(self._kernels, accuracies):
|
||||
clf = Stree(random_state=self._random_state, kernel=kernel,)
|
||||
clf = Stree(
|
||||
random_state=self._random_state,
|
||||
kernel=kernel,
|
||||
)
|
||||
clf.fit(X, y)
|
||||
accuracy_score = clf.score(X, y)
|
||||
yp = clf.predict(X)
|
||||
@@ -314,32 +310,23 @@ class Stree_test(unittest.TestCase):
|
||||
def test_score_multi_class(self):
|
||||
warnings.filterwarnings("ignore")
|
||||
accuracies = [
|
||||
0.8258427, # Wine linear min_distance
|
||||
0.6741573, # Wine linear max_distance
|
||||
0.651685393258427, # Wine linear impurity
|
||||
0.8314607, # Wine linear max_samples
|
||||
0.6629213, # Wine rbf min_distance
|
||||
1.0000000, # Wine rbf max_distance
|
||||
0.6629213483146067, # Wine rbf impurity
|
||||
0.4044944, # Wine rbf max_samples
|
||||
0.9157303, # Wine poly min_distance
|
||||
1.0000000, # Wine poly max_distance
|
||||
0.9157303, # Wine poly impurity
|
||||
0.7640449, # Wine poly max_samples
|
||||
0.9933333, # Iris linear min_distance
|
||||
0.9666667, # Iris linear max_distance
|
||||
0.9933333, # Iris linear impurity
|
||||
0.9666667, # Iris linear max_samples
|
||||
0.9800000, # Iris rbf min_distance
|
||||
0.9800000, # Iris rbf max_distance
|
||||
0.9800000, # Iris rbf impurity
|
||||
0.9800000, # Iris rbf max_samples
|
||||
1.0000000, # Iris poly min_distance
|
||||
1.0000000, # Iris poly max_distance
|
||||
1.0000000, # Iris poly impurity
|
||||
1.0000000, # Iris poly max_samples
|
||||
0.8993333, # Synthetic linear min_distance
|
||||
0.6533333, # Synthetic linear max_distance
|
||||
0.8993333, # Synthetic linear impurity
|
||||
0.9313333, # Synthetic linear max_samples
|
||||
0.8320000, # Synthetic rbf min_distance
|
||||
0.6660000, # Synthetic rbf max_distance
|
||||
0.8320000, # Synthetic rbf impurity
|
||||
0.8320000, # Synthetic rbf max_samples
|
||||
0.6066667, # Synthetic poly min_distance
|
||||
0.6840000, # Synthetic poly max_distance
|
||||
0.6066667, # Synthetic poly impurity
|
||||
0.6340000, # Synthetic poly max_samples
|
||||
]
|
||||
datasets = [
|
||||
@@ -354,8 +341,7 @@ class Stree_test(unittest.TestCase):
|
||||
X, y = dataset
|
||||
for kernel in self._kernels:
|
||||
for criteria in [
|
||||
"min_distance",
|
||||
"max_distance",
|
||||
"impurity",
|
||||
"max_samples",
|
||||
]:
|
||||
clf = Stree(
|
||||
@@ -407,7 +393,13 @@ class Stree_test(unittest.TestCase):
|
||||
original = weights_no_zero.copy()
|
||||
clf = Stree()
|
||||
clf.fit(X, y)
|
||||
node = clf.train(X, y, weights, 1, "test",)
|
||||
node = clf.train(
|
||||
X,
|
||||
y,
|
||||
weights,
|
||||
1,
|
||||
"test",
|
||||
)
|
||||
# if a class is lost with zero weights the patch adds epsilon
|
||||
self.assertListEqual(weights.tolist(), weights_epsilon)
|
||||
self.assertListEqual(node._sample_weight.tolist(), weights_epsilon)
|
||||
|
Reference in New Issue
Block a user