mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-17 08:26:00 +00:00
Document & lint code
This commit is contained in:
223
stree/Strees.py
223
stree/Strees.py
@@ -7,23 +7,28 @@ Build an oblique tree classifier based on SVM Trees
|
||||
Uses LinearSVC
|
||||
'''
|
||||
|
||||
import typing
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from sklearn.base import BaseEstimator, ClassifierMixin
|
||||
from sklearn.svm import LinearSVC
|
||||
from sklearn.utils.multiclass import check_classification_targets
|
||||
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted, _check_sample_weight, check_random_state
|
||||
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted, \
|
||||
_check_sample_weight
|
||||
|
||||
|
||||
class Snode:
|
||||
def __init__(self, clf: LinearSVC, X: np.ndarray, y: np.ndarray, title: str):
|
||||
"""Nodes of the tree that keeps the svm classifier and if testing the
|
||||
dataset assigned to it
|
||||
"""
|
||||
|
||||
def __init__(self, clf: LinearSVC, X: np.ndarray, y: np.ndarray,
|
||||
title: str):
|
||||
self._clf = clf
|
||||
self._vector = None if clf is None else clf.coef_
|
||||
self._interceptor = 0. if clf is None else clf.intercept_
|
||||
self._title = title
|
||||
self._belief = 0. # belief of the prediction in a leaf node based on samples
|
||||
self._belief = 0.
|
||||
# Only store dataset in Testing
|
||||
self._X = X if os.environ.get('TESTING', 'NS') != 'NS' else None
|
||||
self._y = y
|
||||
@@ -51,8 +56,8 @@ class Snode:
|
||||
return self._up
|
||||
|
||||
def make_predictor(self):
|
||||
"""Compute the class of the predictor and its belief based on the subdataset of the node
|
||||
only if it is a leaf
|
||||
"""Compute the class of the predictor and its belief based on the
|
||||
subdataset of the node only if it is a leaf
|
||||
"""
|
||||
if not self.is_leaf():
|
||||
return
|
||||
@@ -62,7 +67,7 @@ class Snode:
|
||||
min_card = min(card)
|
||||
try:
|
||||
self._belief = max_card / (max_card + min_card)
|
||||
except:
|
||||
except ZeroDivisionError:
|
||||
self._belief = 0.
|
||||
self._class = classes[card == max_card][0]
|
||||
else:
|
||||
@@ -71,7 +76,10 @@ class Snode:
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.is_leaf():
|
||||
return f"{self._title} - Leaf class={self._class} belief={self._belief:.6f} counts={np.unique(self._y, return_counts=True)}"
|
||||
count_values = np.unique(self._y, return_counts=True)
|
||||
result = f"{self._title} - Leaf class={self._class} belief="\
|
||||
f"{self._belief: .6f} counts={count_values}"
|
||||
return result
|
||||
else:
|
||||
return f"{self._title}"
|
||||
|
||||
@@ -101,11 +109,16 @@ class Siterator:
|
||||
|
||||
|
||||
class Stree(BaseEstimator, ClassifierMixin):
|
||||
"""
|
||||
"""Estimator that is based on binary trees of svm nodes
|
||||
can deal with sample_weights in predict, used in boosting sklearn methods
|
||||
inheriting from BaseEstimator implements get_params and set_params methods
|
||||
inheriting from ClassifierMixin implement the attribute _estimator_type
|
||||
with "classifier" as value
|
||||
"""
|
||||
|
||||
def __init__(self, C: float = 1.0, max_iter: int = 1000, random_state: int = None,
|
||||
max_depth: int=None, tol: float=1e-4, use_predictions: bool = False):
|
||||
def __init__(self, C: float = 1.0, max_iter: int = 1000,
|
||||
random_state: int = None, max_depth: int = None,
|
||||
tol: float = 1e-4, use_predictions: bool = False):
|
||||
self.max_iter = max_iter
|
||||
self.C = C
|
||||
self.random_state = random_state
|
||||
@@ -113,65 +126,100 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
self.max_depth = max_depth
|
||||
self.tol = tol
|
||||
|
||||
def get_params(self, deep: bool=True) -> dict:
|
||||
"""Get dict with hyperparameters and its values to accomplish sklearn rules
|
||||
"""
|
||||
return {
|
||||
'C': self.C,
|
||||
'random_state': self.random_state,
|
||||
'max_iter': self.max_iter,
|
||||
'use_predictions': self.use_predictions,
|
||||
'max_depth': self.max_depth,
|
||||
'tol': self.tol
|
||||
}
|
||||
|
||||
def set_params(self, **parameters: dict):
|
||||
"""Set hyperparmeters as specified by sklearn, needed in Gridsearchs
|
||||
"""
|
||||
for parameter, value in parameters.items():
|
||||
setattr(self, parameter, value)
|
||||
return self
|
||||
|
||||
# Added binary_only tag as required by sklearn check_estimator
|
||||
def _more_tags(self) -> dict:
|
||||
return {'binary_only': True}
|
||||
"""Required by sklearn to tell that this estimator is a binary classifier
|
||||
|
||||
:return: the tag required
|
||||
:rtype: dict
|
||||
"""
|
||||
return {'binary_only': True, 'requires_y': True}
|
||||
|
||||
def _linear_function(self, data: np.array, node: Snode) -> np.array:
|
||||
"""Compute the distance of set of samples to a hyperplane, in
|
||||
multiclass classification it should compute the distance to a
|
||||
hyperplane of each class
|
||||
|
||||
:param data: dataset of samples
|
||||
:type data: np.array
|
||||
:param node: the node that contains the hyperplance coefficients
|
||||
:type node: Snode
|
||||
:return: array of distances of each sample to the hyperplane
|
||||
:rtype: np.array
|
||||
"""
|
||||
coef = node._vector[0, :].reshape(-1, data.shape[1])
|
||||
return data.dot(coef.T) + node._interceptor[0]
|
||||
|
||||
def _split_array(self, origin: np.array, down: np.array) -> list:
|
||||
"""Split an array in two based on indices passed as down and its complement
|
||||
|
||||
:param origin: dataset to split
|
||||
:type origin: np.array
|
||||
:param down: indices to use to split array
|
||||
:type down: np.array
|
||||
:return: list with two splits of the array
|
||||
:rtype: list
|
||||
"""
|
||||
up = ~down
|
||||
return origin[up[:, 0]] if any(up) else None, \
|
||||
origin[down[:, 0]] if any(down) else None
|
||||
|
||||
def _distances(self, node: Snode, data: np.ndarray) -> np.array:
|
||||
"""Compute distances of the samples to the hyperplane of the node
|
||||
|
||||
:param node: node containing the svm classifier
|
||||
:type node: Snode
|
||||
:param data: samples to find out distance to hyperplane
|
||||
:type data: np.ndarray
|
||||
:return: array of shape (m, 1) with the distances of every sample to
|
||||
the hyperplane of the node
|
||||
:rtype: np.array
|
||||
"""
|
||||
if self.use_predictions:
|
||||
res = np.expand_dims(node._clf.decision_function(data), 1)
|
||||
else:
|
||||
# doesn't work with multiclass as each sample has to do inner product with its own coeficients
|
||||
# computes positition of every sample is w.r.t. the hyperplane
|
||||
"""doesn't work with multiclass as each sample has to do inner
|
||||
product with its own coefficients computes positition of every
|
||||
sample is w.r.t. the hyperplane
|
||||
"""
|
||||
res = self._linear_function(data, node)
|
||||
return res
|
||||
|
||||
def _split_criteria(self, data: np.array) -> np.array:
|
||||
"""Set the criteria to split arrays
|
||||
|
||||
:param data: [description]
|
||||
:type data: np.array
|
||||
:return: [description]
|
||||
:rtype: np.array
|
||||
"""
|
||||
return data > 0
|
||||
|
||||
def fit(self, X: np.ndarray, y: np.ndarray, sample_weight: np.array = None) -> 'Stree':
|
||||
def fit(self, X: np.ndarray, y: np.ndarray,
|
||||
sample_weight: np.array = None) -> 'Stree':
|
||||
"""Build the tree based on the dataset of samples and its labels
|
||||
|
||||
:raises ValueError: if parameters C or max_depth are out of bounds
|
||||
:return: itself to be able to chain actions: fit().predict() ...
|
||||
:rtype: Stree
|
||||
"""
|
||||
# Check parameters are Ok.
|
||||
if type(y).__name__ == 'np.ndarray':
|
||||
y = y.ravel()
|
||||
if self.C < 0:
|
||||
raise ValueError(f"Penalty term must be positive... got (C={self.C:f})")
|
||||
self.__max_depth = np.iinfo(np.int32).max if self.max_depth is None else self.max_depth
|
||||
raise ValueError(
|
||||
f"Penalty term must be positive... got (C={self.C:f})")
|
||||
self.__max_depth = np.iinfo(
|
||||
np.int32).max if self.max_depth is None else self.max_depth
|
||||
if self.__max_depth < 1:
|
||||
raise ValueError(f"Maximum depth has to be greater than 1... got (max_depth={self.max_depth})")
|
||||
raise ValueError(
|
||||
f"Maximum depth has to be greater than 1... got (max_depth=\
|
||||
{self.max_depth})")
|
||||
check_classification_targets(y)
|
||||
X, y = check_X_y(X, y)
|
||||
sample_weight = _check_sample_weight(sample_weight, X)
|
||||
check_classification_targets(y)
|
||||
# Initialize computed parameters
|
||||
self.classes_ = np.unique(y)
|
||||
self.classes_, y = np.unique(y, return_inverse=True)
|
||||
self.n_iter_ = self.max_iter
|
||||
self.depth_ = 0
|
||||
self.n_features_in_ = X.shape[1]
|
||||
@@ -182,7 +230,6 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
def _build_predictor(self):
|
||||
"""Process the leaves to make them predictors
|
||||
"""
|
||||
|
||||
def run_tree(node: Snode):
|
||||
if node.is_leaf():
|
||||
node.make_predictor()
|
||||
@@ -192,16 +239,32 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
|
||||
run_tree(self.tree_)
|
||||
|
||||
def train(self, X: np.ndarray, y: np.ndarray, sample_weight: np.ndarray, depth: int, title: str) -> Snode:
|
||||
|
||||
def train(self, X: np.ndarray, y: np.ndarray, sample_weight: np.ndarray,
|
||||
depth: int, title: str) -> Snode:
|
||||
"""Recursive function to split the original dataset into predictor
|
||||
nodes (leaves)
|
||||
|
||||
:param X: samples dataset
|
||||
:type X: np.ndarray
|
||||
:param y: samples labels
|
||||
:type y: np.ndarray
|
||||
:param sample_weight: weight of samples (used in boosting)
|
||||
:type sample_weight: np.ndarray
|
||||
:param depth: actual depth in the tree
|
||||
:type depth: int
|
||||
:param title: description of the node
|
||||
:type title: str
|
||||
:return: binary tree
|
||||
:rtype: Snode
|
||||
"""
|
||||
if depth > self.__max_depth:
|
||||
return None
|
||||
if np.unique(y).shape[0] == 1 :
|
||||
if np.unique(y).shape[0] == 1:
|
||||
# only 1 class => pure dataset
|
||||
return Snode(None, X, y, title + ', <pure>')
|
||||
# Train the model
|
||||
clf = LinearSVC(max_iter=self.max_iter, random_state=self.random_state,
|
||||
C=self.C) #, sample_weight=sample_weight)
|
||||
C=self.C) # , sample_weight=sample_weight)
|
||||
clf.fit(X, y, sample_weight=sample_weight)
|
||||
tree = Snode(clf, X, y, title)
|
||||
self.depth_ = max(depth, self.depth_)
|
||||
@@ -217,6 +280,15 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
return tree
|
||||
|
||||
def _reorder_results(self, y: np.array, indices: np.array) -> np.array:
|
||||
"""Reorder an array based on the array of indices passed
|
||||
|
||||
:param y: data untidy
|
||||
:type y: np.array
|
||||
:param indices: indices used to set order
|
||||
:type indices: np.array
|
||||
:return: array y ordered
|
||||
:rtype: np.array
|
||||
"""
|
||||
if y.ndim > 1 and y.shape[1] > 1:
|
||||
# if predict_proba return np.array of floats
|
||||
y_ordered = np.zeros(y.shape, dtype=float)
|
||||
@@ -229,7 +301,15 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
return y_ordered
|
||||
|
||||
def predict(self, X: np.array) -> np.array:
|
||||
def predict_class(xp: np.array, indices: np.array, node: Snode) -> np.array:
|
||||
"""Predict labels for each sample in dataset passed
|
||||
|
||||
:param X: dataset of samples
|
||||
:type X: np.array
|
||||
:return: array of labels
|
||||
:rtype: np.array
|
||||
"""
|
||||
def predict_class(xp: np.array, indices: np.array,
|
||||
node: Snode) -> np.array:
|
||||
if xp is None:
|
||||
return [], []
|
||||
if node.is_leaf():
|
||||
@@ -242,29 +322,36 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
prx_u, prin_u = predict_class(X_U, i_u, node.get_up())
|
||||
prx_d, prin_d = predict_class(X_D, i_d, node.get_down())
|
||||
return np.append(prx_u, prx_d), np.append(prin_u, prin_d)
|
||||
|
||||
# sklearn check
|
||||
check_is_fitted(self, ['tree_'])
|
||||
# Input validation
|
||||
X = check_array(X)
|
||||
# setup prediction & make it happen
|
||||
indices = np.arange(X.shape[0])
|
||||
return self._reorder_results(*predict_class(X, indices, self.tree_)).ravel()
|
||||
result = self._reorder_results(
|
||||
*predict_class(X, indices, self.tree_)).astype(int).ravel()
|
||||
return self.classes_[result]
|
||||
|
||||
def predict_proba(self, X: np.array) -> np.array:
|
||||
"""Computes an approximation of the probability of samples belonging to class 0 and 1
|
||||
"""Computes an approximation of the probability of samples belonging to
|
||||
class 0 and 1
|
||||
:param X: dataset
|
||||
:type X: np.array
|
||||
:return: array array of shape (m, num_classes), probability of being
|
||||
each class
|
||||
:rtype: np.array
|
||||
"""
|
||||
|
||||
def predict_class(xp: np.array, indices: np.array, dist: np.array, node: Snode) -> np.array:
|
||||
def predict_class(xp: np.array, indices: np.array, dist: np.array,
|
||||
node: Snode) -> np.array:
|
||||
"""Run the tree to compute predictions
|
||||
|
||||
:param xp: subdataset of samples
|
||||
:type xp: np.array
|
||||
:param indices: indices of subdataset samples to rebuild original order
|
||||
:param indices: indices of subdataset samples to rebuild original
|
||||
order
|
||||
:type indices: np.array
|
||||
:param dist: distances of every sample to the hyperplane or the father node
|
||||
:param dist: distances of every sample to the hyperplane or the
|
||||
father node
|
||||
:type dist: np.array
|
||||
:param node: node of the leaf with the class
|
||||
:type node: Snode
|
||||
@@ -280,7 +367,6 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
return np.append(prediction, prediction_proba, axis=1), indices
|
||||
distances = self._distances(node, xp)
|
||||
down = self._split_criteria(distances)
|
||||
|
||||
X_U, X_D = self._split_array(xp, down)
|
||||
i_u, i_d = self._split_array(indices, down)
|
||||
di_u, di_d = self._split_array(distances, down)
|
||||
@@ -297,15 +383,24 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
empty_dist = np.empty((X.shape[0], 1), dtype=float)
|
||||
result, indices = predict_class(X, indices, empty_dist, self.tree_)
|
||||
result = result.reshape(X.shape[0], 2)
|
||||
# Turn distances to hyperplane into probabilities based on fitting distances
|
||||
# of samples to its hyperplane that classified them, to the sigmoid function
|
||||
# Turn distances to hyperplane into probabilities based on fitting
|
||||
# distances of samples to its hyperplane that classified them, to the
|
||||
# sigmoid function
|
||||
# Probability of being 1
|
||||
result[:, 1] = 1 / (1 + np.exp(-result[:, 1]))
|
||||
result[:, 0] = 1 - result[:, 1] # Probability of being 0
|
||||
# Probability of being 0
|
||||
result[:, 0] = 1 - result[:, 1]
|
||||
return self._reorder_results(result, indices)
|
||||
|
||||
def score(self, X: np.array, y: np.array) -> float:
|
||||
"""Return accuracy
|
||||
"""Compute accuracy of the prediction
|
||||
|
||||
:param X: dataset of samples to make predictions
|
||||
:type X: np.array
|
||||
:param y: samples labels
|
||||
:type y: np.array
|
||||
:return: accuracy of the prediction
|
||||
:rtype: float
|
||||
"""
|
||||
# sklearn check
|
||||
check_is_fitted(self)
|
||||
@@ -313,15 +408,25 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
return np.mean(yp == y)
|
||||
|
||||
def __iter__(self) -> Siterator:
|
||||
"""Create an iterator to be able to visit the nodes of the tree in preorder,
|
||||
can make a list with all the nodes in preorder
|
||||
|
||||
:return: an iterator, can for i in... and list(...)
|
||||
:rtype: Siterator
|
||||
"""
|
||||
try:
|
||||
tree = self.tree_
|
||||
except:
|
||||
except AttributeError:
|
||||
tree = None
|
||||
return Siterator(tree)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation of the tree
|
||||
|
||||
:return: description of nodes in the tree in preorder
|
||||
:rtype: str
|
||||
"""
|
||||
output = ''
|
||||
for i in self:
|
||||
output += str(i) + '\n'
|
||||
return output
|
||||
|
||||
|
@@ -1,4 +1,3 @@
|
||||
import csv
|
||||
import os
|
||||
import unittest
|
||||
|
||||
@@ -22,18 +21,22 @@ class Stree_test(unittest.TestCase):
|
||||
def tearDownClass(cls):
|
||||
try:
|
||||
os.environ.pop('TESTING')
|
||||
except:
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def _get_Xy(self):
|
||||
X, y = make_classification(n_samples=1500, n_features=3, n_informative=3,
|
||||
n_redundant=0, n_repeated=0, n_classes=2, n_clusters_per_class=2,
|
||||
class_sep=1.5, flip_y=0, weights=[0.5, 0.5], random_state=self._random_state)
|
||||
X, y = make_classification(n_samples=1500, n_features=3,
|
||||
n_informative=3, n_redundant=0,
|
||||
n_repeated=0, n_classes=2,
|
||||
n_clusters_per_class=2, class_sep=1.5,
|
||||
flip_y=0, weights=[0.5, 0.5],
|
||||
random_state=self._random_state)
|
||||
return X, y
|
||||
|
||||
def _check_tree(self, node: Snode):
|
||||
"""Check recursively that the nodes that are not leaves have the correct
|
||||
number of labels and its sons have the right number of elements in their dataset
|
||||
"""Check recursively that the nodes that are not leaves have the
|
||||
correct number of labels and its sons have the right number of elements
|
||||
in their dataset
|
||||
|
||||
Arguments:
|
||||
node {Snode} -- node to check
|
||||
@@ -53,11 +56,11 @@ class Stree_test(unittest.TestCase):
|
||||
for i in unique_y:
|
||||
try:
|
||||
number_down = count_d[i]
|
||||
except:
|
||||
except IndexError:
|
||||
number_down = 0
|
||||
try:
|
||||
number_up = count_u[i]
|
||||
except:
|
||||
except IndexError:
|
||||
number_up = 0
|
||||
self.assertEqual(count_y[i], number_down + number_up)
|
||||
# Is the partition made the same as the prediction?
|
||||
@@ -89,7 +92,8 @@ class Stree_test(unittest.TestCase):
|
||||
fx = np.delete(data, column_y, axis=1)
|
||||
return fx, fy
|
||||
|
||||
def _find_out(self, px: np.array, x_original: np.array, y_original) -> list:
|
||||
def _find_out(self, px: np.array, x_original: np.array,
|
||||
y_original) -> list:
|
||||
"""Find the original values of y for a given array of samples
|
||||
|
||||
Arguments:
|
||||
@@ -128,16 +132,18 @@ class Stree_test(unittest.TestCase):
|
||||
self.assertGreater(accuracy_score, 0.9)
|
||||
|
||||
def test_single_predict_proba(self):
|
||||
"""Check that element 28 has a prediction different that the current label
|
||||
"""Check that element 28 has a prediction different that the current
|
||||
label
|
||||
"""
|
||||
# Element 28 has a different prediction than the truth
|
||||
decimals = 5
|
||||
prob = 0.29026400766
|
||||
X, y = self._get_Xy()
|
||||
yp = self._clf.predict_proba(X[28, :].reshape(-1, X.shape[1]))
|
||||
self.assertEqual(np.round(1 - prob, decimals), np.round(yp[0:, 0], decimals))
|
||||
self.assertEqual(np.round(1 - prob, decimals),
|
||||
np.round(yp[0:, 0], decimals))
|
||||
self.assertEqual(1, y[28])
|
||||
|
||||
|
||||
self.assertAlmostEqual(
|
||||
round(prob, decimals),
|
||||
round(yp[0, 1], decimals),
|
||||
@@ -150,11 +156,16 @@ class Stree_test(unittest.TestCase):
|
||||
decimals = 5
|
||||
X, y = self._get_Xy()
|
||||
yp = self._clf.predict_proba(X[:num, :])
|
||||
self.assertListEqual(y[:num].tolist(), np.argmax(yp[:num], axis=1).tolist())
|
||||
expected_proba = [0.88395641, 0.36746962, 0.84158767, 0.34106833, 0.14269291, 0.85193236,
|
||||
0.29876058, 0.7282164, 0.85958616, 0.89517877, 0.99745224, 0.18860349,
|
||||
0.30756427, 0.8318412, 0.18981198, 0.15564624, 0.25740655, 0.22923355,
|
||||
0.87365959, 0.49928689, 0.95574351, 0.28761257, 0.28906333, 0.32643692,
|
||||
self.assertListEqual(
|
||||
y[:num].tolist(), np.argmax(yp[:num], axis=1).tolist())
|
||||
expected_proba = [0.88395641, 0.36746962, 0.84158767, 0.34106833,
|
||||
0.14269291, 0.85193236,
|
||||
0.29876058, 0.7282164, 0.85958616, 0.89517877,
|
||||
0.99745224, 0.18860349,
|
||||
0.30756427, 0.8318412, 0.18981198, 0.15564624,
|
||||
0.25740655, 0.22923355,
|
||||
0.87365959, 0.49928689, 0.95574351, 0.28761257,
|
||||
0.28906333, 0.32643692,
|
||||
0.29788483, 0.01657364, 0.81149083]
|
||||
expected = np.round(expected_proba, decimals=decimals).tolist()
|
||||
computed = np.round(yp[:, 1], decimals=decimals).tolist()
|
||||
@@ -162,9 +173,10 @@ class Stree_test(unittest.TestCase):
|
||||
self.assertAlmostEqual(expected[i], computed[i], decimals)
|
||||
|
||||
def build_models(self):
|
||||
"""Build and train two models, model_clf will use the sklearn classifier to
|
||||
compute predictions and split data. model_computed will use vector of
|
||||
coefficients to compute both predictions and splitted data
|
||||
"""Build and train two models, model_clf will use the sklearn
|
||||
classifier to compute predictions and split data. model_computed will
|
||||
use vector of coefficients to compute both predictions and splitted
|
||||
data
|
||||
"""
|
||||
model_clf = Stree(random_state=self._random_state,
|
||||
use_predictions=True)
|
||||
@@ -176,8 +188,9 @@ class Stree_test(unittest.TestCase):
|
||||
return model_clf, model_computed, X, y
|
||||
|
||||
def test_use_model_predict(self):
|
||||
"""Check that we get the same results wether we use the estimator in nodes
|
||||
to compute labels or we use the hyperplane and the position of samples wrt to it
|
||||
"""Check that we get the same results wether we use the estimator in
|
||||
nodes to compute labels or we use the hyperplane and the position of
|
||||
samples wrt to it
|
||||
"""
|
||||
use_clf, use_math, X, _ = self.build_models()
|
||||
self.assertListEqual(
|
||||
@@ -202,14 +215,15 @@ class Stree_test(unittest.TestCase):
|
||||
)
|
||||
|
||||
def test_single_vs_multiple_prediction(self):
|
||||
"""Check if predicting sample by sample gives the same result as predicting
|
||||
all samples at once
|
||||
"""Check if predicting sample by sample gives the same result as
|
||||
predicting all samples at once
|
||||
"""
|
||||
X, _ = self._get_Xy()
|
||||
# Compute prediction line by line
|
||||
yp_line = np.array([], dtype=int)
|
||||
for xp in X:
|
||||
yp_line = np.append(yp_line, self._clf.predict(xp.reshape(-1, X.shape[1])))
|
||||
yp_line = np.append(yp_line, self._clf.predict(
|
||||
xp.reshape(-1, X.shape[1])))
|
||||
# Compute prediction at once
|
||||
yp_once = self._clf.predict(X)
|
||||
#
|
||||
@@ -221,11 +235,15 @@ class Stree_test(unittest.TestCase):
|
||||
expected = [
|
||||
'root',
|
||||
'root - Down',
|
||||
'root - Down - Down, <cgaf> - Leaf class=1 belief=0.975989 counts=(array([0, 1]), array([ 17, 691]))',
|
||||
'root - Down - Down, <cgaf> - Leaf class=1 belief= 0.975989 counts'
|
||||
'=(array([0, 1]), array([ 17, 691]))',
|
||||
'root - Down - Up',
|
||||
'root - Down - Up - Down, <cgaf> - Leaf class=1 belief=0.750000 counts=(array([0, 1]), array([1, 3]))',
|
||||
'root - Down - Up - Up, <pure> - Leaf class=0 belief=1.000000 counts=(array([0]), array([7]))',
|
||||
'root - Up, <cgaf> - Leaf class=0 belief=0.928297 counts=(array([0, 1]), array([725, 56]))',
|
||||
'root - Down - Up - Down, <cgaf> - Leaf class=1 belief= 0.750000 '
|
||||
'counts=(array([0, 1]), array([1, 3]))',
|
||||
'root - Down - Up - Up, <pure> - Leaf class=0 belief= 1.000000 '
|
||||
'counts=(array([0]), array([7]))',
|
||||
'root - Up, <cgaf> - Leaf class=0 belief= 0.928297 counts=(array('
|
||||
'[0, 1]), array([725, 56]))',
|
||||
]
|
||||
computed = []
|
||||
for node in self._clf:
|
||||
@@ -253,10 +271,10 @@ class Stree_test(unittest.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
tcl = Stree(max_depth=-1)
|
||||
tcl.fit(*self._get_Xy())
|
||||
|
||||
|
||||
def test_check_max_depth(self):
|
||||
depth = 3
|
||||
tcl = Stree(random_state=self._random_state, max_depth=depth)
|
||||
tcl = Stree(random_state=self._random_state, max_depth=depth)
|
||||
tcl.fit(*self._get_Xy())
|
||||
self.assertEqual(depth, tcl.depth_)
|
||||
|
||||
@@ -264,6 +282,7 @@ class Stree_test(unittest.TestCase):
|
||||
tcl = Stree()
|
||||
self.assertEqual(0, len(list(tcl)))
|
||||
|
||||
|
||||
class Snode_test(unittest.TestCase):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -276,19 +295,24 @@ class Snode_test(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
"""[summary]
|
||||
"""
|
||||
try:
|
||||
os.environ.pop('TESTING')
|
||||
except:
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def _get_Xy(self):
|
||||
X, y = make_classification(n_samples=1500, n_features=3, n_informative=3,
|
||||
n_redundant=0, n_repeated=0, n_classes=2, n_clusters_per_class=2,
|
||||
class_sep=1.5, flip_y=0, weights=[0.5, 0.5], random_state=self._random_state)
|
||||
X, y = make_classification(n_samples=1500, n_features=3,
|
||||
n_informative=3, n_redundant=0, n_classes=2,
|
||||
n_repeated=0, n_clusters_per_class=2,
|
||||
class_sep=1.5, flip_y=0, weights=[0.5, 0.5],
|
||||
random_state=self._random_state)
|
||||
return X, y
|
||||
|
||||
def test_attributes_in_leaves(self):
|
||||
"""Check if the attributes in leaves have correct values so they form a predictor
|
||||
"""Check if the attributes in leaves have correct values so they form a
|
||||
predictor
|
||||
"""
|
||||
|
||||
def check_leave(node: Snode):
|
||||
@@ -303,7 +327,7 @@ class Snode_test(unittest.TestCase):
|
||||
if len(classes) > 1:
|
||||
try:
|
||||
belief = max_card / (max_card + min_card)
|
||||
except:
|
||||
except ZeroDivisionError:
|
||||
belief = 0.
|
||||
else:
|
||||
belief = 1
|
||||
|
Reference in New Issue
Block a user