mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-16 16:06:01 +00:00
File diff suppressed because one or more lines are too long
290
stree/Strees.py
290
stree/Strees.py
@@ -9,12 +9,14 @@ Build an oblique tree classifier based on SVM Trees
|
|||||||
import os
|
import os
|
||||||
import numbers
|
import numbers
|
||||||
import random
|
import random
|
||||||
|
import warnings
|
||||||
from itertools import combinations
|
from itertools import combinations
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sklearn.base import BaseEstimator, ClassifierMixin
|
from sklearn.base import BaseEstimator, ClassifierMixin
|
||||||
from sklearn.svm import SVC, LinearSVC
|
from sklearn.svm import SVC, LinearSVC
|
||||||
from sklearn.utils import check_consistent_length
|
from sklearn.utils import check_consistent_length
|
||||||
from sklearn.utils.multiclass import check_classification_targets
|
from sklearn.utils.multiclass import check_classification_targets
|
||||||
|
from sklearn.exceptions import ConvergenceWarning
|
||||||
from sklearn.utils.validation import (
|
from sklearn.utils.validation import (
|
||||||
check_X_y,
|
check_X_y,
|
||||||
check_array,
|
check_array,
|
||||||
@@ -134,6 +136,168 @@ class Siterator:
|
|||||||
return node
|
return node
|
||||||
|
|
||||||
|
|
||||||
|
class Splitter:
    """Select feature subspaces and partition samples for a tree node.

    The splitter is configured once with an impurity ``criterion``
    (``"gini"`` or ``"entropy"``), a multiclass decision ``criteria``
    (``"min_distance"`` or ``"max_samples"``) and a ``splitter_type``
    (``"random"`` or ``"best"``).

    :meth:`partition` stores the boolean mask of the last split in
    ``self._down`` and :meth:`part` applies that mask to any array, so
    :meth:`partition` has to be called before :meth:`part`.
    """

    def __init__(
        self,
        clf: "SVC" = None,
        criterion: str = None,
        splitter_type: str = None,
        criteria: str = None,
        min_samples_split: int = None,
        random_state=None,
    ):
        """Store the configuration and validate it.

        :param clf: estimator fitted on every candidate subspace when
            ``splitter_type`` is ``"best"``
        :param criterion: impurity measure, ``"gini"`` or ``"entropy"``
        :param splitter_type: ``"random"`` or ``"best"``
        :param criteria: reduction applied to multi-column distances,
            ``"min_distance"`` or ``"max_samples"``
        :param min_samples_split: nodes with fewer samples than this value
            are not split; ``None`` disables the check
        :param random_state: if not ``None``, seeds the global ``random``
            module
        :raises ValueError: if ``clf`` is ``None`` or any of the string
            options is not one of the accepted values
        """
        self._clf = clf
        self._random_state = random_state
        if random_state is not None:
            random.seed(random_state)
        self._criterion = criterion
        self._min_samples_split = min_samples_split
        self._criteria = criteria
        self._splitter_type = splitter_type

        if clf is None:
            raise ValueError(f"clf has to be a sklearn estimator, got({clf})")

        if criterion not in ("gini", "entropy"):
            raise ValueError(
                f"criterion must be gini or entropy got({criterion})"
            )

        if criteria not in ("min_distance", "max_samples"):
            # Implicit concatenation keeps source indentation out of the
            # message (a backslash continuation inside the literal does not)
            raise ValueError(
                "split_criteria has to be min_distance or max_samples "
                f"got ({criteria})"
            )

        if splitter_type not in ("random", "best"):
            raise ValueError(
                f"splitter must be either random or best got({splitter_type})"
            )
        # Resolve the option names to the matching private static methods
        self.criterion_function = getattr(self, f"_{self._criterion}")
        self.decision_criteria = getattr(self, f"_{self._criteria}")

    def impurity(self, y: np.ndarray) -> float:
        """Return the impurity of the labels ``y`` using the criterion."""
        return self.criterion_function(y)

    @staticmethod
    def _gini(y: np.ndarray) -> float:
        """Return the Gini impurity of the labels ``y``."""
        _, count = np.unique(y, return_counts=True)
        return 1 - np.sum(np.square(count / np.sum(count)))

    @staticmethod
    def _entropy(y: np.ndarray) -> float:
        """Return the Shannon entropy (base 2) of the labels ``y``."""
        _, count = np.unique(y, return_counts=True)
        proportion = count / np.sum(count)
        return -np.sum(proportion * np.log2(proportion))

    def information_gain(
        self, labels_up: np.ndarray, labels_dn: np.ndarray
    ) -> float:
        """Return the size-weighted impurity of the two sides of a split.

        Despite its name, the value returned is the weighted sum of the
        impurities of both partitions, so lower values mean purer splits.

        :param labels_up: labels of the "up" partition, or ``None`` when
            that partition is empty (as returned by :meth:`part`)
        :param labels_dn: labels of the "down" partition, or ``None`` when
            that partition is empty
        :return: weighted impurity; ``0.0`` when there are no samples
        """
        # part() yields None for an empty side, which must count as zero
        # samples instead of failing on ``None.shape``
        sides = [side for side in (labels_up, labels_dn) if side is not None]
        samples = sum(side.shape[0] for side in sides)
        if samples == 0:
            return 0.0
        return sum(
            side.shape[0] / samples * self.criterion_function(side)
            for side in sides
        )

    def _select_best_set(
        self, dataset: np.ndarray, labels: np.ndarray, features_sets: list
    ) -> list:
        """Return the feature set whose split has the lowest impurity.

        Every candidate set is used to fit ``self._clf``, the samples are
        partitioned with the fitted classifier and the weighted impurity of
        the partition is computed.

        :param dataset: samples, one row per sample
        :param labels: labels of the samples
        :param features_sets: candidate tuples of column indices
        :return: the candidate with the lowest weighted impurity (the first
            one on ties), ``None`` if there are no candidates
        """
        # Entropy is not bounded by 1 with more than two classes, so start
        # from infinity to make sure a candidate is always selected
        min_impurity = np.inf
        selected = None
        # Silence convergence warnings only while searching, without
        # leaving the process-wide warning filters modified
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=ConvergenceWarning)
            for feature_set in features_sets:
                self._clf.fit(dataset[:, feature_set], labels)
                node = Snode(
                    self._clf, dataset, labels, feature_set, 0.0, "subset"
                )
                self.partition(dataset, node)
                y1, y2 = self.part(labels)
                impurity = self.information_gain(y1, y2)
                if impurity < min_impurity:
                    min_impurity = impurity
                    selected = feature_set
        return selected

    def _get_subspaces_set(
        self, dataset: np.ndarray, labels: np.ndarray, max_features: int
    ) -> tuple:
        """Return the column indices of the subspace to use in a split.

        All combinations of ``max_features`` columns are generated; one is
        picked at random or by lowest impurity depending on the splitter
        type. When only one combination exists it is returned directly.
        """
        features = range(dataset.shape[1])
        features_sets = list(combinations(features, max_features))
        if len(features_sets) > 1:
            if self._splitter_type == "random":
                # randint (not random.choice) keeps the sequence drawn from
                # a given seed unchanged
                return features_sets[random.randint(0, len(features_sets) - 1)]
            return self._select_best_set(dataset, labels, features_sets)
        return features_sets[0]

    def get_subspace(
        self, dataset: np.ndarray, labels: np.ndarray, max_features: int
    ) -> tuple:
        """Return the best subspace to make a split

        :param dataset: samples, one row per sample
        :param labels: labels of the samples
        :param max_features: number of columns of the subspace
        :return: the selected columns of ``dataset`` and their indices
        """
        indices = self._get_subspaces_set(dataset, labels, max_features)
        return dataset[:, indices], indices

    @staticmethod
    def _min_distance(data: np.ndarray, _) -> np.ndarray:
        """Return, for every row, the value with the lowest absolute value.

        :param data: two dimensional array of distances
        :param _: unused, present to share the signature of
            :meth:`_max_samples`
        """
        # chooses the lowest distance of every sample
        indices = np.argmin(np.abs(data), axis=1)
        return data[np.arange(data.shape[0]), indices]

    @staticmethod
    def _max_samples(data: np.ndarray, y: np.ndarray) -> np.ndarray:
        """Return the column of ``data`` of the most frequent label in ``y``.

        The column index is the position of that label among the sorted
        unique values of ``y``.
        """
        # select the class with max number of samples
        _, samples = np.unique(y, return_counts=True)
        selected = np.argmax(samples)
        return data[:, selected]

    def partition(self, samples: np.ndarray, node: "Snode"):
        """Set the criteria to split arrays

        Computes the distances of ``samples`` to the hyperplane(s) of
        ``node`` and stores in ``self._down`` a boolean mask that is True
        for the samples with a positive distance. With fewer samples than
        ``min_samples_split`` every sample is marked as down.

        :param samples: samples to partition
        :param node: node holding the fitted classifier, the features it
            uses and the labels it was built with
        """
        data = self._distances(node, samples)
        if (
            self._min_samples_split is not None
            and data.shape[0] < self._min_samples_split
        ):
            self._down = np.ones((data.shape[0]), dtype=bool)
            return
        if data.ndim > 1:
            # split criteria for multiclass
            data = self.decision_criteria(data, node._y)
        self._down = data > 0

    def _distances(self, node: "Snode", data: np.ndarray) -> np.ndarray:
        """Compute distances of the samples to the hyperplane of the node

        :param node: node containing the svm classifier
        :type node: Snode
        :param data: samples to find out distance to hyperplane
        :type data: np.ndarray
        :return: value of the decision function of the classifier of the
            node for every sample, computed on the features of the node
        :rtype: np.ndarray
        """
        return node._clf.decision_function(data[:, node._features])

    def part(self, origin: np.ndarray) -> list:
        """Split an array in two based on the mask set by :meth:`partition`

        :param origin: dataset to split
        :type origin: np.ndarray
        :return: list with two splits of the array, first the elements not
            selected by the mask (up) and then the selected ones (down); an
            empty split is returned as None
        :rtype: list
        """
        up = ~self._down
        return [
            origin[up] if up.any() else None,
            origin[self._down] if self._down.any() else None,
        ]
|
||||||
|
|
||||||
|
|
||||||
class Stree(BaseEstimator, ClassifierMixin):
|
class Stree(BaseEstimator, ClassifierMixin):
|
||||||
"""Estimator that is based on binary trees of svm nodes
|
"""Estimator that is based on binary trees of svm nodes
|
||||||
can deal with sample_weights in predict, used in boosting sklearn methods
|
can deal with sample_weights in predict, used in boosting sklearn methods
|
||||||
@@ -156,6 +320,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
criterion: str = "gini",
|
criterion: str = "gini",
|
||||||
min_samples_split: int = 0,
|
min_samples_split: int = 0,
|
||||||
max_features=None,
|
max_features=None,
|
||||||
|
splitter: str = "random",
|
||||||
):
|
):
|
||||||
self.max_iter = max_iter
|
self.max_iter = max_iter
|
||||||
self.C = C
|
self.C = C
|
||||||
@@ -169,6 +334,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
self.split_criteria = split_criteria
|
self.split_criteria = split_criteria
|
||||||
self.max_features = max_features
|
self.max_features = max_features
|
||||||
self.criterion = criterion
|
self.criterion = criterion
|
||||||
|
self.splitter = splitter
|
||||||
|
|
||||||
def _more_tags(self) -> dict:
|
def _more_tags(self) -> dict:
|
||||||
"""Required by sklearn to supply features of the classifier
|
"""Required by sklearn to supply features of the classifier
|
||||||
@@ -178,68 +344,6 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
"""
|
"""
|
||||||
return {"requires_y": True}
|
return {"requires_y": True}
|
||||||
|
|
||||||
def _split_array(self, origin: np.array, down: np.array) -> list:
|
|
||||||
"""Split an array in two based on indices (down) and its complement
|
|
||||||
|
|
||||||
:param origin: dataset to split
|
|
||||||
:type origin: np.array
|
|
||||||
:param down: indices to use to split array
|
|
||||||
:type down: np.array
|
|
||||||
:return: list with two splits of the array
|
|
||||||
:rtype: list
|
|
||||||
"""
|
|
||||||
up = ~down
|
|
||||||
return [
|
|
||||||
origin[up] if any(up) else None,
|
|
||||||
origin[down] if any(down) else None,
|
|
||||||
]
|
|
||||||
|
|
||||||
def _distances(self, node: Snode, data: np.ndarray) -> np.array:
|
|
||||||
"""Compute distances of the samples to the hyperplane of the node
|
|
||||||
|
|
||||||
:param node: node containing the svm classifier
|
|
||||||
:type node: Snode
|
|
||||||
:param data: samples to find out distance to hyperplane
|
|
||||||
:type data: np.ndarray
|
|
||||||
:return: array of shape (m, 1) with the distances of every sample to
|
|
||||||
the hyperplane of the node
|
|
||||||
:rtype: np.array
|
|
||||||
"""
|
|
||||||
return node._clf.decision_function(data[:, node._features])
|
|
||||||
|
|
||||||
def _min_distance(self, data: np.array, _) -> np.array:
|
|
||||||
# chooses the lowest distance of every sample
|
|
||||||
indices = np.argmin(np.abs(data), axis=1)
|
|
||||||
return np.array(
|
|
||||||
[data[x, y] for x, y in zip(range(len(data[:, 0])), indices)]
|
|
||||||
)
|
|
||||||
|
|
||||||
def _max_samples(self, data: np.array, y: np.array) -> np.array:
|
|
||||||
# select the class with max number of samples
|
|
||||||
_, samples = np.unique(y, return_counts=True)
|
|
||||||
selected = np.argmax(samples)
|
|
||||||
return data[:, selected]
|
|
||||||
|
|
||||||
def _split_criteria(self, data: np.array, node: Snode) -> np.array:
|
|
||||||
"""Set the criteria to split arrays
|
|
||||||
|
|
||||||
:param data: distances of samples to hyperplanes shape (m, nclasses)
|
|
||||||
if nclasses > 2 else (m,)
|
|
||||||
:type data: np.array
|
|
||||||
:param node: node containing the svm classifier
|
|
||||||
:type node: Snode
|
|
||||||
:return: array of booleans of samples under or above zero
|
|
||||||
:rtype: np.array
|
|
||||||
"""
|
|
||||||
|
|
||||||
if data.shape[0] < self.min_samples_split:
|
|
||||||
return np.ones((data.shape[0]), dtype=bool)
|
|
||||||
if data.ndim > 1:
|
|
||||||
# split criteria for multiclass
|
|
||||||
data = getattr(self, f"_{self.split_criteria}")(data, node._y)
|
|
||||||
res = data > 0
|
|
||||||
return res
|
|
||||||
|
|
||||||
def fit(
|
def fit(
|
||||||
self, X: np.ndarray, y: np.ndarray, sample_weight: np.array = None
|
self, X: np.ndarray, y: np.ndarray, sample_weight: np.array = None
|
||||||
) -> "Stree":
|
) -> "Stree":
|
||||||
@@ -271,21 +375,20 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
f"Maximum depth has to be greater than 1... got (max_depth=\
|
f"Maximum depth has to be greater than 1... got (max_depth=\
|
||||||
{self.max_depth})"
|
{self.max_depth})"
|
||||||
)
|
)
|
||||||
if self.split_criteria not in ["min_distance", "max_samples"]:
|
|
||||||
raise ValueError(
|
|
||||||
f"split_criteria has to be min_distance or \
|
|
||||||
max_samples got ({self.split_criteria})"
|
|
||||||
)
|
|
||||||
if self.criterion not in ["gini", "entropy"]:
|
|
||||||
raise ValueError(
|
|
||||||
f"criterion must be gini or entropy got({self.criterion})"
|
|
||||||
)
|
|
||||||
|
|
||||||
check_classification_targets(y)
|
check_classification_targets(y)
|
||||||
X, y = check_X_y(X, y)
|
X, y = check_X_y(X, y)
|
||||||
sample_weight = _check_sample_weight(sample_weight, X)
|
sample_weight = _check_sample_weight(sample_weight, X)
|
||||||
check_classification_targets(y)
|
check_classification_targets(y)
|
||||||
# Initialize computed parameters
|
# Initialize computed parameters
|
||||||
|
self.splitter_ = Splitter(
|
||||||
|
clf=self._build_clf(),
|
||||||
|
criterion=self.criterion,
|
||||||
|
splitter_type=self.splitter,
|
||||||
|
criteria=self.split_criteria,
|
||||||
|
random_state=self.random_state,
|
||||||
|
min_samples_split=self.min_samples_split,
|
||||||
|
)
|
||||||
if self.random_state is not None:
|
if self.random_state is not None:
|
||||||
random.seed(self.random_state)
|
random.seed(self.random_state)
|
||||||
self.classes_, y = np.unique(y, return_inverse=True)
|
self.classes_, y = np.unique(y, return_inverse=True)
|
||||||
@@ -295,7 +398,6 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
self.n_features_ = X.shape[1]
|
self.n_features_ = X.shape[1]
|
||||||
self.n_features_in_ = X.shape[1]
|
self.n_features_in_ = X.shape[1]
|
||||||
self.max_features_ = self._initialize_max_features()
|
self.max_features_ = self._initialize_max_features()
|
||||||
self.criterion_function_ = getattr(self, f"_{self.criterion}")
|
|
||||||
self.tree_ = self.train(X, y, sample_weight, 1, "root")
|
self.tree_ = self.train(X, y, sample_weight, 1, "root")
|
||||||
self._build_predictor()
|
self._build_predictor()
|
||||||
return self
|
return self
|
||||||
@@ -339,15 +441,15 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
)
|
)
|
||||||
# Train the model
|
# Train the model
|
||||||
clf = self._build_clf()
|
clf = self._build_clf()
|
||||||
Xs, features = self._get_subspace(X)
|
Xs, features = self.splitter_.get_subspace(X, y, self.max_features_)
|
||||||
clf.fit(Xs, y, sample_weight=sample_weight)
|
clf.fit(Xs, y, sample_weight=sample_weight)
|
||||||
impurity = self.criterion_function_(y)
|
impurity = self.splitter_.impurity(y)
|
||||||
node = Snode(clf, X, y, features, impurity, title)
|
node = Snode(clf, X, y, features, impurity, title)
|
||||||
self.depth_ = max(depth, self.depth_)
|
self.depth_ = max(depth, self.depth_)
|
||||||
down = self._split_criteria(self._distances(node, X), node)
|
self.splitter_.partition(X, node)
|
||||||
X_U, X_D = self._split_array(X, down)
|
X_U, X_D = self.splitter_.part(X)
|
||||||
y_u, y_d = self._split_array(y, down)
|
y_u, y_d = self.splitter_.part(y)
|
||||||
sw_u, sw_d = self._split_array(sample_weight, down)
|
sw_u, sw_d = self.splitter_.part(sample_weight)
|
||||||
if X_U is None or X_D is None:
|
if X_U is None or X_D is None:
|
||||||
# didn't part anything
|
# didn't part anything
|
||||||
return Snode(
|
return Snode(
|
||||||
@@ -431,9 +533,9 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
# set a class for every sample in dataset
|
# set a class for every sample in dataset
|
||||||
prediction = np.full((xp.shape[0], 1), node._class)
|
prediction = np.full((xp.shape[0], 1), node._class)
|
||||||
return prediction, indices
|
return prediction, indices
|
||||||
down = self._split_criteria(self._distances(node, xp), node)
|
self.splitter_.partition(xp, node)
|
||||||
x_u, x_d = self._split_array(xp, down)
|
x_u, x_d = self.splitter_.part(xp)
|
||||||
i_u, i_d = self._split_array(indices, down)
|
i_u, i_d = self.splitter_.part(indices)
|
||||||
prx_u, prin_u = predict_class(x_u, i_u, node.get_up())
|
prx_u, prin_u = predict_class(x_u, i_u, node.get_up())
|
||||||
prx_d, prin_d = predict_class(x_d, i_d, node.get_down())
|
prx_d, prin_d = predict_class(x_d, i_d, node.get_down())
|
||||||
return np.append(prx_u, prx_d), np.append(prin_u, prin_d)
|
return np.append(prx_u, prx_d), np.append(prin_u, prin_d)
|
||||||
@@ -536,29 +638,3 @@ class Stree(BaseEstimator, ClassifierMixin):
|
|||||||
f"got ({self.max_features})"
|
f"got ({self.max_features})"
|
||||||
)
|
)
|
||||||
return max_features
|
return max_features
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _gini(y: np.array) -> float:
|
|
||||||
_, count = np.unique(y, return_counts=True)
|
|
||||||
return 1 - np.sum(np.square(count / np.sum(count)))
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _entropy(y: np.array) -> float:
|
|
||||||
_, count = np.unique(y, return_counts=True)
|
|
||||||
proportion = count / np.sum(count)
|
|
||||||
return -np.sum(proportion * np.log2(proportion))
|
|
||||||
|
|
||||||
def _get_subspace(self, dataset: np.array) -> list:
|
|
||||||
"""Return the best subspace to make a split
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_subspaces_set(dataset: np.array) -> np.array:
|
|
||||||
features = range(dataset.shape[1])
|
|
||||||
features_sets = list(combinations(features, self.max_features_))
|
|
||||||
if len(features_sets) > 1:
|
|
||||||
return features_sets[random.randint(0, len(features_sets) - 1)]
|
|
||||||
else:
|
|
||||||
return features_sets[0]
|
|
||||||
|
|
||||||
indices = get_subspaces_set(dataset)
|
|
||||||
return dataset[:, indices], indices
|
|
||||||
|
@@ -1,3 +1,3 @@
|
|||||||
from .Strees import Stree, Snode, Siterator
|
from .Strees import Stree, Snode, Siterator, Splitter
|
||||||
|
|
||||||
__all__ = ["Stree", "Snode", "Siterator"]
|
__all__ = ["Stree", "Snode", "Siterator", "Splitter"]
|
||||||
|
142
stree/tests/Splitter_test.py
Normal file
142
stree/tests/Splitter_test.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from sklearn.svm import LinearSVC
|
||||||
|
|
||||||
|
from stree import Splitter
|
||||||
|
from .utils import load_dataset
|
||||||
|
|
||||||
|
|
||||||
|
class Splitter_test(unittest.TestCase):
    """Unit tests of the Splitter class."""

    # Sentinel that lets build() tell "no clf given" apart from an explicit
    # clf=None, which the tests use to check the validation of Splitter
    _DEFAULT_CLF = object()

    def __init__(self, *args, **kwargs):
        self._random_state = 1
        super().__init__(*args, **kwargs)

    def build(
        self,
        clf=_DEFAULT_CLF,
        min_samples_split=0,
        splitter_type="random",
        criterion="gini",
        criteria="min_distance",
        random_state=None,
    ):
        """Return a Splitter built with the given arguments.

        A new LinearSVC is created on every call when ``clf`` is not given,
        so the tests do not share an estimator that Splitter fits in place.
        """
        if clf is self._DEFAULT_CLF:
            clf = LinearSVC()
        return Splitter(
            clf=clf,
            min_samples_split=min_samples_split,
            splitter_type=splitter_type,
            criterion=criterion,
            criteria=criteria,
            random_state=random_state,
        )

    def setUp(self):
        os.environ["TESTING"] = "1"

    def test_init(self):
        """Invalid arguments raise ValueError, valid ones are stored."""
        with self.assertRaises(ValueError):
            self.build(criterion="duck")
        with self.assertRaises(ValueError):
            self.build(splitter_type="duck")
        with self.assertRaises(ValueError):
            self.build(criteria="duck")
        with self.assertRaises(ValueError):
            self.build(clf=None)
        for splitter_type in ["best", "random"]:
            for criterion in ["gini", "entropy"]:
                for criteria in ["min_distance", "max_samples"]:
                    tcl = self.build(
                        splitter_type=splitter_type,
                        criterion=criterion,
                        criteria=criteria,
                    )
                    self.assertEqual(splitter_type, tcl._splitter_type)
                    self.assertEqual(criterion, tcl._criterion)
                    self.assertEqual(criteria, tcl._criteria)

    def test_gini(self):
        y = [0, 1, 1, 1, 1, 1, 0, 0, 0, 1]
        expected = 0.48
        # computed floats are compared with a tolerance, not exactly
        self.assertAlmostEqual(expected, Splitter._gini(y))
        tcl = self.build(criterion="gini")
        self.assertAlmostEqual(expected, tcl.criterion_function(y))

    def test_entropy(self):
        y = [0, 1, 1, 1, 1, 1, 0, 0, 0, 1]
        expected = 0.9709505944546686
        self.assertAlmostEqual(expected, Splitter._entropy(y))
        tcl = self.build(criterion="entropy")
        self.assertAlmostEqual(expected, tcl.criterion_function(y))

    def test_information_gain(self):
        yu = np.array([0, 1, 1, 1, 1, 1])
        yd = np.array([0, 0, 0, 1])
        values_expected = [
            ("gini", 0.31666666666666665),
            ("entropy", 0.7145247027726656),
        ]
        for criterion, expected in values_expected:
            tcl = self.build(criterion=criterion)
            computed = tcl.information_gain(yu, yd)
            self.assertAlmostEqual(expected, computed)

    def test_max_samples(self):
        tcl = self.build(criteria="max_samples")
        data = np.array(
            [
                [-0.1, 0.2, -0.3],
                [0.7, 0.01, -0.1],
                [0.7, -0.9, 0.5],
                [0.1, 0.2, 0.3],
            ]
        )
        expected = np.array([0.2, 0.01, -0.9, 0.2])
        y = [1, 2, 1, 0]
        computed = tcl._max_samples(data, y)
        self.assertEqual((4,), computed.shape)
        self.assertListEqual(expected.tolist(), computed.tolist())

    def test_min_distance(self):
        tcl = self.build()
        data = np.array(
            [
                [-0.1, 0.2, -0.3],
                [0.7, 0.01, -0.1],
                [0.7, -0.9, 0.5],
                [0.1, 0.2, 0.3],
            ]
        )
        expected = np.array([-0.1, 0.01, 0.5, 0.1])
        computed = tcl._min_distance(data, None)
        self.assertEqual((4,), computed.shape)
        self.assertListEqual(expected.tolist(), computed.tolist())

    def test_splitter_parameter(self):
        # One expected subspace per combination, in the order of the loops
        expected_values = iter(
            [
                [1, 7, 9],
                [1, 7, 9],
                [1, 7, 9],
                [1, 7, 9],
                [0, 5, 6],
                [0, 5, 6],
                [0, 5, 6],
                [0, 5, 6],
            ]
        )
        X, y = load_dataset(self._random_state, n_features=12)
        for splitter_type in ["best", "random"]:
            for criterion in ["gini", "entropy"]:
                for criteria in ["min_distance", "max_samples"]:
                    tcl = self.build(
                        splitter_type=splitter_type,
                        criterion=criterion,
                        criteria=criteria,
                        random_state=self._random_state,
                    )
                    expected = next(expected_values)
                    dataset, computed = tcl.get_subspace(X, y, max_features=3)
                    self.assertListEqual(expected, list(computed))
                    self.assertListEqual(
                        X[:, computed].tolist(), dataset.tolist()
                    )
|
@@ -204,13 +204,11 @@ class Stree_test(unittest.TestCase):
|
|||||||
self.assertEqual(0, len(list(tcl)))
|
self.assertEqual(0, len(list(tcl)))
|
||||||
|
|
||||||
def test_min_samples_split(self):
|
def test_min_samples_split(self):
|
||||||
tcl_split = Stree(min_samples_split=3)
|
|
||||||
tcl_nosplit = Stree(min_samples_split=4)
|
|
||||||
dataset = [[1], [2], [3]], [1, 1, 0]
|
dataset = [[1], [2], [3]], [1, 1, 0]
|
||||||
tcl_split.fit(*dataset)
|
tcl_split = Stree(min_samples_split=3).fit(*dataset)
|
||||||
self.assertIsNotNone(tcl_split.tree_.get_down())
|
self.assertIsNotNone(tcl_split.tree_.get_down())
|
||||||
self.assertIsNotNone(tcl_split.tree_.get_up())
|
self.assertIsNotNone(tcl_split.tree_.get_up())
|
||||||
tcl_nosplit.fit(*dataset)
|
tcl_nosplit = Stree(min_samples_split=4).fit(*dataset)
|
||||||
self.assertIsNone(tcl_nosplit.tree_.get_down())
|
self.assertIsNone(tcl_nosplit.tree_.get_down())
|
||||||
self.assertIsNone(tcl_nosplit.tree_.get_up())
|
self.assertIsNone(tcl_nosplit.tree_.get_up())
|
||||||
|
|
||||||
@@ -265,37 +263,6 @@ class Stree_test(unittest.TestCase):
|
|||||||
outcome = outcomes[name][f"{criteria} {kernel}"]
|
outcome = outcomes[name][f"{criteria} {kernel}"]
|
||||||
self.assertAlmostEqual(outcome, clf.score(px, py))
|
self.assertAlmostEqual(outcome, clf.score(px, py))
|
||||||
|
|
||||||
def test_min_distance(self):
|
|
||||||
clf = Stree()
|
|
||||||
data = np.array(
|
|
||||||
[
|
|
||||||
[-0.1, 0.2, -0.3],
|
|
||||||
[0.7, 0.01, -0.1],
|
|
||||||
[0.7, -0.9, 0.5],
|
|
||||||
[0.1, 0.2, 0.3],
|
|
||||||
]
|
|
||||||
)
|
|
||||||
expected = np.array([-0.1, 0.01, 0.5, 0.1])
|
|
||||||
computed = clf._min_distance(data, None)
|
|
||||||
self.assertEqual((4,), computed.shape)
|
|
||||||
self.assertListEqual(expected.tolist(), computed.tolist())
|
|
||||||
|
|
||||||
def test_max_samples(self):
|
|
||||||
clf = Stree()
|
|
||||||
data = np.array(
|
|
||||||
[
|
|
||||||
[-0.1, 0.2, -0.3],
|
|
||||||
[0.7, 0.01, -0.1],
|
|
||||||
[0.7, -0.9, 0.5],
|
|
||||||
[0.1, 0.2, 0.3],
|
|
||||||
]
|
|
||||||
)
|
|
||||||
expected = np.array([0.2, 0.01, -0.9, 0.2])
|
|
||||||
y = [1, 2, 1, 0]
|
|
||||||
computed = clf._max_samples(data, y)
|
|
||||||
self.assertEqual((4,), computed.shape)
|
|
||||||
self.assertListEqual(expected.tolist(), computed.tolist())
|
|
||||||
|
|
||||||
def test_max_features(self):
|
def test_max_features(self):
|
||||||
n_features = 16
|
n_features = 16
|
||||||
expected_values = [
|
expected_values = [
|
||||||
@@ -334,7 +301,9 @@ class Stree_test(unittest.TestCase):
|
|||||||
for max_features, expected in expected_values:
|
for max_features, expected in expected_values:
|
||||||
clf.set_params(**dict(max_features=max_features))
|
clf.set_params(**dict(max_features=max_features))
|
||||||
clf.fit(dataset, y)
|
clf.fit(dataset, y)
|
||||||
computed, indices = clf._get_subspace(dataset)
|
computed, indices = clf.splitter_.get_subspace(
|
||||||
|
dataset, y, clf.max_features_
|
||||||
|
)
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
dataset[:, indices].tolist(), computed.tolist()
|
dataset[:, indices].tolist(), computed.tolist()
|
||||||
)
|
)
|
||||||
@@ -345,22 +314,6 @@ class Stree_test(unittest.TestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
clf.fit(*load_dataset())
|
clf.fit(*load_dataset())
|
||||||
|
|
||||||
def test_gini(self):
|
|
||||||
y = [0, 1, 1, 1, 1, 1, 0, 0, 0, 1]
|
|
||||||
expected = 0.48
|
|
||||||
self.assertEqual(expected, Stree._gini(y))
|
|
||||||
clf = Stree(criterion="gini")
|
|
||||||
clf.fit(*load_dataset())
|
|
||||||
self.assertEqual(expected, clf.criterion_function_(y))
|
|
||||||
|
|
||||||
def test_entropy(self):
|
|
||||||
y = [0, 1, 1, 1, 1, 1, 0, 0, 0, 1]
|
|
||||||
expected = 0.9709505944546686
|
|
||||||
self.assertAlmostEqual(expected, Stree._entropy(y))
|
|
||||||
clf = Stree(criterion="entropy")
|
|
||||||
clf.fit(*load_dataset())
|
|
||||||
self.assertEqual(expected, clf.criterion_function_(y))
|
|
||||||
|
|
||||||
def test_predict_feature_dimensions(self):
|
def test_predict_feature_dimensions(self):
|
||||||
X = np.random.rand(10, 5)
|
X = np.random.rand(10, 5)
|
||||||
y = np.random.randint(0, 2, 10)
|
y = np.random.randint(0, 2, 10)
|
||||||
@@ -374,3 +327,8 @@ class Stree_test(unittest.TestCase):
|
|||||||
clf = Stree(random_state=self._random_state, max_features=2)
|
clf = Stree(random_state=self._random_state, max_features=2)
|
||||||
clf.fit(X, y)
|
clf.fit(X, y)
|
||||||
self.assertAlmostEqual(0.9426666666666667, clf.score(X, y))
|
self.assertAlmostEqual(0.9426666666666667, clf.score(X, y))
|
||||||
|
|
||||||
|
def test_bogus_splitter_parameter(self):
    """Fitting with an unknown value of splitter raises ValueError."""
    with self.assertRaises(ValueError):
        Stree(splitter="duck").fit(*load_dataset())
|
||||||
|
@@ -1,4 +1,5 @@
|
|||||||
from .Stree_test import Stree_test
|
from .Stree_test import Stree_test
|
||||||
from .Snode_test import Snode_test
|
from .Snode_test import Snode_test
|
||||||
|
from .Splitter_test import Splitter_test
|
||||||
|
|
||||||
__all__ = ["Stree_test", "Snode_test"]
|
__all__ = ["Stree_test", "Snode_test", "Splitter_test"]
|
||||||
|
@@ -1,10 +1,10 @@
|
|||||||
from sklearn.datasets import make_classification
|
from sklearn.datasets import make_classification
|
||||||
|
|
||||||
|
|
||||||
def load_dataset(random_state=0, n_classes=2):
|
def load_dataset(random_state=0, n_classes=2, n_features=3):
|
||||||
X, y = make_classification(
|
X, y = make_classification(
|
||||||
n_samples=1500,
|
n_samples=1500,
|
||||||
n_features=3,
|
n_features=n_features,
|
||||||
n_informative=3,
|
n_informative=3,
|
||||||
n_redundant=0,
|
n_redundant=0,
|
||||||
n_repeated=0,
|
n_repeated=0,
|
||||||
|
Reference in New Issue
Block a user