mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-15 23:46:02 +00:00
Complete implementation of splitter_type = impurity with tests
Remove max_distance & min_distance splitter types
This commit is contained in:
@@ -57,6 +57,7 @@ class Snode:
|
||||
)
|
||||
self._features = features
|
||||
self._impurity = impurity
|
||||
self._partition_column: int = -1
|
||||
|
||||
@classmethod
|
||||
def copy(cls, node: "Snode") -> "Snode":
|
||||
@@ -69,6 +70,12 @@ class Snode:
|
||||
node._title,
|
||||
)
|
||||
|
||||
def set_partition_column(self, col: int):
|
||||
self._partition_column = col
|
||||
|
||||
def get_partition_column(self) -> int:
|
||||
return self._partition_column
|
||||
|
||||
def set_down(self, son):
|
||||
self._down = son
|
||||
|
||||
@@ -239,7 +246,7 @@ class Splitter:
|
||||
node = Snode(
|
||||
self._clf, dataset, labels, feature_set, 0.0, "subset"
|
||||
)
|
||||
self.partition(dataset, node)
|
||||
self.partition(dataset, node, train=True)
|
||||
y1, y2 = self.part(labels)
|
||||
gain = self.information_gain(labels, y1, y2)
|
||||
if gain > max_gain:
|
||||
@@ -272,7 +279,7 @@ class Splitter:
|
||||
return dataset[:, indices], indices
|
||||
|
||||
def _impurity(self, data: np.array, _) -> np.array:
|
||||
"""return distances of the class whose partition has less impurity
|
||||
"""return column of dataset to be taken into account to split dataset
|
||||
|
||||
:param data: distances to hyper plane of every class
|
||||
:type data: np.array (m, n_classes)
|
||||
@@ -289,15 +296,15 @@ class Splitter:
|
||||
y[data > 0] = 1
|
||||
y = y.astype(int)
|
||||
for col in range(data.shape[1]):
|
||||
impurity_of_class = self.partition_impurity(y[col])
|
||||
impurity_of_class = self.partition_impurity(y[:, col])
|
||||
if impurity_of_class < min_impurity:
|
||||
selected = col
|
||||
min_impurity = impurity_of_class
|
||||
return data[:, selected]
|
||||
return selected
|
||||
|
||||
@staticmethod
|
||||
def _max_samples(data: np.array, y: np.array) -> np.array:
|
||||
"""return distances of the class with more samples
|
||||
"""return column of dataset to be taken into account to split dataset
|
||||
|
||||
:param data: distances to hyper plane of every class
|
||||
:type data: np.array (m, n_classes)
|
||||
@@ -308,10 +315,9 @@ class Splitter:
|
||||
"""
|
||||
# select the class with max number of samples
|
||||
_, samples = np.unique(y, return_counts=True)
|
||||
selected = np.argmax(samples)
|
||||
return data[:, selected]
|
||||
return np.argmax(samples)
|
||||
|
||||
def partition(self, samples: np.array, node: Snode):
|
||||
def partition(self, samples: np.array, node: Snode, train: bool):
|
||||
"""Set the criteria to split arrays. Compute the indices of the samples
|
||||
that should go to one side of the tree (down)
|
||||
|
||||
@@ -325,7 +331,16 @@ class Splitter:
|
||||
if data.ndim > 1:
|
||||
# split criteria for multiclass
|
||||
# Convert data to a (m, 1) array selecting values for samples
|
||||
data = self.decision_criteria(data, node._y)
|
||||
if train:
|
||||
# in train time we have to compute the column to take into
|
||||
# account to split the dataset
|
||||
col = self.decision_criteria(data, node._y)
|
||||
node.set_partition_column(col)
|
||||
else:
|
||||
# in predict time just use the column computed in train time
|
||||
# i.e. take the classifier of class <col>
|
||||
col = node.get_partition_column()
|
||||
data = data[:, col]
|
||||
self._down = data > 0
|
||||
|
||||
@staticmethod
|
||||
@@ -344,6 +359,7 @@ class Splitter:
|
||||
|
||||
def part(self, origin: np.array) -> list:
|
||||
"""Split an array in two based on indices (down) and its complement
|
||||
partition has to be called first to establish down indices
|
||||
|
||||
:param origin: dataset to split
|
||||
:type origin: np.array
|
||||
@@ -377,7 +393,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
tol: float = 1e-4,
|
||||
degree: int = 3,
|
||||
gamma="scale",
|
||||
split_criteria: str = "max_samples",
|
||||
split_criteria: str = "impurity",
|
||||
criterion: str = "gini",
|
||||
min_samples_split: int = 0,
|
||||
max_features=None,
|
||||
@@ -518,7 +534,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
impurity = self.splitter_.partition_impurity(y)
|
||||
node = Snode(clf, X, y, features, impurity, title, sample_weight)
|
||||
self.depth_ = max(depth, self.depth_)
|
||||
self.splitter_.partition(X, node)
|
||||
self.splitter_.partition(X, node, True)
|
||||
X_U, X_D = self.splitter_.part(X)
|
||||
y_u, y_d = self.splitter_.part(y)
|
||||
sw_u, sw_d = self.splitter_.part(sample_weight)
|
||||
@@ -605,7 +621,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
# set a class for every sample in dataset
|
||||
prediction = np.full((xp.shape[0], 1), node._class)
|
||||
return prediction, indices
|
||||
self.splitter_.partition(xp, node)
|
||||
self.splitter_.partition(xp, node, train=False)
|
||||
x_u, x_d = self.splitter_.part(xp)
|
||||
i_u, i_d = self.splitter_.part(indices)
|
||||
prx_u, prin_u = predict_class(x_u, i_u, node.get_up())
|
||||
|
@@ -40,12 +40,13 @@ class Snode_test(unittest.TestCase):
|
||||
# Check Class
|
||||
class_computed = classes[card == max_card]
|
||||
self.assertEqual(class_computed, node._class)
|
||||
# Check Partition column
|
||||
self.assertEqual(node._partition_column, -1)
|
||||
|
||||
check_leave(self._clf.tree_)
|
||||
|
||||
def test_nodes_coefs(self):
|
||||
"""Check if the nodes of the tree have the right attributes filled
|
||||
"""
|
||||
"""Check if the nodes of the tree have the right attributes filled"""
|
||||
|
||||
def run_tree(node: Snode):
|
||||
if node._belief < 1:
|
||||
@@ -54,16 +55,19 @@ class Snode_test(unittest.TestCase):
|
||||
self.assertIsNotNone(node._clf.coef_)
|
||||
if node.is_leaf():
|
||||
return
|
||||
run_tree(node.get_down())
|
||||
run_tree(node.get_up())
|
||||
run_tree(node.get_down())
|
||||
|
||||
run_tree(self._clf.tree_)
|
||||
model = Stree(self._random_state)
|
||||
model.fit(*load_dataset(self._random_state, 3, 4))
|
||||
run_tree(model.tree_)
|
||||
|
||||
def test_make_predictor_on_leaf(self):
|
||||
test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test")
|
||||
test.make_predictor()
|
||||
self.assertEqual(1, test._class)
|
||||
self.assertEqual(0.75, test._belief)
|
||||
self.assertEqual(-1, test._partition_column)
|
||||
|
||||
def test_make_predictor_on_not_leaf(self):
|
||||
test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test")
|
||||
@@ -71,11 +75,14 @@ class Snode_test(unittest.TestCase):
|
||||
test.make_predictor()
|
||||
self.assertIsNone(test._class)
|
||||
self.assertEqual(0, test._belief)
|
||||
self.assertEqual(-1, test._partition_column)
|
||||
self.assertEqual(-1, test.get_up()._partition_column)
|
||||
|
||||
def test_make_predictor_on_leaf_bogus_data(self):
|
||||
test = Snode(None, [1, 2, 3, 4], [], [], 0.0, "test")
|
||||
test.make_predictor()
|
||||
self.assertIsNone(test._class)
|
||||
self.assertEqual(-1, test._partition_column)
|
||||
|
||||
def test_copy_node(self):
|
||||
px = [1, 2, 3, 4]
|
||||
@@ -86,3 +93,4 @@ class Snode_test(unittest.TestCase):
|
||||
self.assertListEqual(computed._y, py)
|
||||
self.assertEqual("test", computed._title)
|
||||
self.assertIsInstance(computed._clf, Stree)
|
||||
self.assertEqual(test._partition_column, computed._partition_column)
|
||||
|
@@ -134,13 +134,17 @@ class Splitter_test(unittest.TestCase):
|
||||
[0.7, 0.01, -0.1],
|
||||
[0.7, -0.9, 0.5],
|
||||
[0.1, 0.2, 0.3],
|
||||
[-0.1, 0.2, 0.3],
|
||||
[-0.1, 0.2, 0.3],
|
||||
]
|
||||
)
|
||||
expected = np.array([0.2, 0.01, -0.9, 0.2])
|
||||
y = [1, 2, 1, 0]
|
||||
expected = np.array([-0.1, 0.7, 0.7, 0.1, -0.1, -0.1])
|
||||
y = [1, 2, 1, 0, 0, 0]
|
||||
computed = tcl._max_samples(data, y)
|
||||
self.assertEqual((4,), computed.shape)
|
||||
self.assertListEqual(expected.tolist(), computed.tolist())
|
||||
self.assertEqual(0, computed)
|
||||
computed_data = data[:, computed]
|
||||
self.assertEqual((6,), computed_data.shape)
|
||||
self.assertListEqual(expected.tolist(), computed_data.tolist())
|
||||
|
||||
def test_impurity(self):
|
||||
tcl = self.build(criteria="impurity")
|
||||
@@ -150,12 +154,16 @@ class Splitter_test(unittest.TestCase):
|
||||
[0.7, 0.01, -0.1],
|
||||
[0.7, -0.9, 0.5],
|
||||
[0.1, 0.2, 0.3],
|
||||
[-0.1, 0.2, 0.3],
|
||||
[-0.1, 0.2, 0.3],
|
||||
]
|
||||
)
|
||||
expected = np.array([-0.1, 0.7, 0.7, 0.1])
|
||||
expected = np.array([0.2, 0.01, -0.9, 0.2, 0.2, 0.2])
|
||||
computed = tcl._impurity(data, None)
|
||||
self.assertEqual((4,), computed.shape)
|
||||
self.assertListEqual(expected.tolist(), computed.tolist())
|
||||
self.assertEqual(1, computed)
|
||||
computed_data = data[:, computed]
|
||||
self.assertEqual((6,), computed_data.shape)
|
||||
self.assertListEqual(expected.tolist(), computed_data.tolist())
|
||||
|
||||
def test_best_splitter_few_sets(self):
|
||||
X, y = load_iris(return_X_y=True)
|
||||
@@ -168,9 +176,9 @@ class Splitter_test(unittest.TestCase):
|
||||
def test_splitter_parameter(self):
|
||||
expected_values = [
|
||||
[0, 1, 7, 9], # best entropy max_samples
|
||||
[3, 8, 10, 11], # best entropy impurity
|
||||
[0, 2, 4, 5], # best entropy impurity
|
||||
[0, 2, 8, 12], # best gini max_samples
|
||||
[1, 2, 5, 12], # best gini impurity
|
||||
[4, 5, 9, 12], # best gini impurity
|
||||
[1, 2, 5, 10], # random entropy max_samples
|
||||
[4, 8, 9, 12], # random entropy impurity
|
||||
[3, 9, 11, 12], # random gini max_samples
|
||||
|
@@ -217,7 +217,7 @@ class Stree_test(unittest.TestCase):
|
||||
random_state=self._random_state,
|
||||
)
|
||||
clf.fit(px, py)
|
||||
print(f"{name} {criteria} {kernel}")
|
||||
# print(f"{name} {criteria} {kernel}")
|
||||
outcome = outcomes[name][f"{criteria} {kernel}"]
|
||||
self.assertAlmostEqual(outcome, clf.score(px, py))
|
||||
|
||||
@@ -310,23 +310,23 @@ class Stree_test(unittest.TestCase):
|
||||
def test_score_multi_class(self):
|
||||
warnings.filterwarnings("ignore")
|
||||
accuracies = [
|
||||
0.651685393258427, # Wine linear impurity
|
||||
0.7022472, # Wine linear impurity
|
||||
0.8314607, # Wine linear max_samples
|
||||
0.6629213483146067, # Wine rbf impurity
|
||||
0.4044944, # Wine rbf impurity
|
||||
0.4044944, # Wine rbf max_samples
|
||||
0.9157303, # Wine poly impurity
|
||||
0.3988764, # Wine poly impurity
|
||||
0.7640449, # Wine poly max_samples
|
||||
0.9933333, # Iris linear impurity
|
||||
0.6600000, # Iris linear impurity
|
||||
0.9666667, # Iris linear max_samples
|
||||
0.9800000, # Iris rbf impurity
|
||||
0.3333333, # Iris rbf impurity
|
||||
0.9800000, # Iris rbf max_samples
|
||||
1.0000000, # Iris poly impurity
|
||||
0.3333333, # Iris poly impurity
|
||||
1.0000000, # Iris poly max_samples
|
||||
0.8993333, # Synthetic linear impurity
|
||||
0.7153333, # Synthetic linear impurity
|
||||
0.9313333, # Synthetic linear max_samples
|
||||
0.8320000, # Synthetic rbf impurity
|
||||
0.4806667, # Synthetic rbf impurity
|
||||
0.8320000, # Synthetic rbf max_samples
|
||||
0.6066667, # Synthetic poly impurity
|
||||
0.4786667, # Synthetic poly impurity
|
||||
0.6340000, # Synthetic poly max_samples
|
||||
]
|
||||
datasets = [
|
||||
|
Reference in New Issue
Block a user