mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-15 15:36:00 +00:00
# 2 - add max_features parameters
This commit is contained in:
@@ -7,7 +7,9 @@ Build an oblique tree classifier based on SVM Trees
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import numbers
|
||||
import random
|
||||
from itertools import combinations
|
||||
import numpy as np
|
||||
from sklearn.base import BaseEstimator, ClassifierMixin
|
||||
from sklearn.svm import SVC, LinearSVC
|
||||
@@ -127,8 +129,9 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
tol: float = 1e-4,
|
||||
degree: int = 3,
|
||||
gamma="scale",
|
||||
split_criteria="max_samples",
|
||||
split_criteria: str = "max_samples",
|
||||
min_samples_split: int = 0,
|
||||
max_features=None,
|
||||
):
|
||||
self.max_iter = max_iter
|
||||
self.C = C
|
||||
@@ -140,6 +143,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
self.degree = degree
|
||||
self.min_samples_split = min_samples_split
|
||||
self.split_criteria = split_criteria
|
||||
self.max_features = max_features
|
||||
|
||||
def _more_tags(self) -> dict:
|
||||
"""Required by sklearn to supply features of the classifier
|
||||
@@ -160,10 +164,10 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
:rtype: list
|
||||
"""
|
||||
up = ~down
|
||||
return (
|
||||
return [
|
||||
origin[up] if any(up) else None,
|
||||
origin[down] if any(down) else None,
|
||||
)
|
||||
]
|
||||
|
||||
def _distances(self, node: Snode, data: np.ndarray) -> np.array:
|
||||
"""Compute distances of the samples to the hyperplane of the node
|
||||
@@ -257,7 +261,8 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
self.n_classes_ = self.classes_.shape[0]
|
||||
self.n_iter_ = self.max_iter
|
||||
self.depth_ = 0
|
||||
self.n_features_in_ = X.shape[1]
|
||||
self.n_features_ = X.shape[1]
|
||||
self.max_features_ = self._initialize_max_features()
|
||||
self.tree_ = self.train(X, y, sample_weight, 1, "root")
|
||||
self._build_predictor()
|
||||
return self
|
||||
@@ -294,10 +299,11 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
return Snode(None, X, y, title + ", <pure>")
|
||||
# Train the model
|
||||
clf = self._build_clf()
|
||||
clf.fit(X, y, sample_weight=sample_weight)
|
||||
node = Snode(clf, X, y, title)
|
||||
Xs, indices_subset = self._get_subspace(X)
|
||||
clf.fit(Xs, y, sample_weight=sample_weight)
|
||||
node = Snode(clf, Xs, y, title)
|
||||
self.depth_ = max(depth, self.depth_)
|
||||
down = self._split_criteria(self._distances(node, X), node)
|
||||
down = self._split_criteria(self._distances(node, Xs), node)
|
||||
X_U, X_D = self._split_array(X, down)
|
||||
y_u, y_d = self._split_array(y, down)
|
||||
sw_u, sw_d = self._split_array(sample_weight, down)
|
||||
@@ -446,3 +452,49 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
for i in self:
|
||||
output += str(i) + "\n"
|
||||
return output
|
||||
|
||||
def _initialize_max_features(self) -> int:
    """Resolve ``self.max_features`` into a concrete number of features.

    Reads ``self.max_features`` and ``self.n_features_`` and returns how
    many features are to be used when a subspace is selected:

    - ``None``: all ``self.n_features_`` features.
    - ``"auto"`` or ``"sqrt"``: ``int(sqrt(n_features_))``, at least 1.
    - ``"log2"``: ``int(log2(n_features_))``, at least 1.
    - integer: that exact number; must be in ``[1, n_features_]``.
    - float: that fraction of ``n_features_`` (at least 1); must be in
      ``(0, 1]``.

    :raises ValueError: if ``max_features`` is an unknown string, an
        integer outside ``[1, n_features_]`` or a float outside ``(0, 1]``
    :return: number of features to use
    :rtype: int
    """
    n_features = self.n_features_
    max_features = self.max_features
    if max_features is None:
        return n_features
    if isinstance(max_features, str):
        # "auto" is an alias of "sqrt"
        if max_features in ("auto", "sqrt"):
            return max(1, int(np.sqrt(n_features)))
        if max_features == "log2":
            return max(1, int(np.log2(n_features)))
        raise ValueError(
            "Invalid value for max_features. "
            "Allowed string values are 'auto', "
            "'sqrt' or 'log2'."
        )
    if isinstance(max_features, numbers.Integral):
        # An out-of-range count cannot produce a valid feature subset, so
        # fail here with a clear message instead of later on
        if not 1 <= max_features <= n_features:
            raise ValueError(
                "Invalid value for max_features. "
                f"Allowed integer must be in range [1, {n_features}] "
                f"got ({max_features})"
            )
        return int(max_features)
    # Anything else is treated as a float fraction of the features; the
    # negated chained comparison also rejects NaN
    if not 0.0 < max_features <= 1.0:
        raise ValueError(
            "Invalid value for max_features. "
            "Allowed float must be in range (0, 1] "
            f"got ({max_features})"
        )
    return max(1, int(max_features * n_features))
|
||||
|
||||
def _get_subspace(self, dataset: np.ndarray) -> tuple:
    """Select a random subset of the columns (features) of the dataset

    The number of columns selected is ``self.max_features_``. Column
    indices are drawn without replacement using the ``random`` module and
    returned in ascending order, so every combination of that size is
    equally likely. The combinations are never enumerated, which keeps the
    cost proportional to the number of features instead of the number of
    possible combinations.

    :param dataset: array of samples, one column per feature
    :type dataset: np.ndarray
    :raises ValueError: if ``self.max_features_`` is negative
    :return: the dataset restricted to the selected columns and the tuple
        with the indices of those columns
    :rtype: tuple
    """
    n_features = dataset.shape[1]
    n_selected = self.max_features_
    if n_selected >= n_features:
        # Only one possible subspace: every feature. No random draw needed
        indices = tuple(range(n_features))
    else:
        # A sorted sample without replacement is a uniformly chosen
        # combination, in the same order itertools.combinations would give
        indices = tuple(sorted(random.sample(range(n_features), n_selected)))
    return dataset[:, indices], indices
|
||||
|
@@ -295,3 +295,47 @@ class Stree_test(unittest.TestCase):
|
||||
computed = clf._max_samples(data, y)
|
||||
self.assertEqual((4,), computed.shape)
|
||||
self.assertListEqual(expected.tolist(), computed.tolist())
|
||||
|
||||
def test_max_features(self):
    """Check how each kind of max_features value is resolved"""
    n_features = 16
    # max_features setting -> number of features expected out of 16
    valid_cases = {
        "auto": 4,
        "log2": 4,
        "sqrt": 4,
        0.5: 8,
        3: 3,
        None: 16,
    }
    clf = Stree()
    clf.n_features_ = n_features
    for setting, expected in valid_cases.items():
        clf.set_params(max_features=setting)
        self.assertEqual(expected, clf._initialize_max_features())
    # Check bogus max_features
    for bogus in ("duck", -0.1, 0.0):
        clf.set_params(max_features=bogus)
        with self.assertRaises(ValueError):
            clf._initialize_max_features()
|
||||
|
||||
def test_get_subspaces(self):
    """Check the subspace returned for each kind of max_features value

    The subspace must be the dataset restricted to the returned column
    indices, and the number of indices must match the expected count.
    """
    dataset = np.random.random((10, 16))
    y = np.random.randint(0, 2, 10)
    # (max_features setting, number of columns expected out of 16)
    cases = (
        ("auto", 4),
        ("log2", 4),
        ("sqrt", 4),
        (0.5, 8),
        (3, 3),
        (None, 16),
    )
    clf = Stree()
    for setting, n_expected in cases:
        clf.set_params(max_features=setting)
        clf.fit(dataset, y)
        subspace, columns = clf._get_subspace(dataset)
        restricted = dataset[:, columns]
        self.assertListEqual(restricted.tolist(), subspace.tolist())
        self.assertEqual(n_expected, len(columns))
|
||||
|
Reference in New Issue
Block a user