mirror of https://github.com/Doctorado-ML/STree.git
synced 2025-08-15 15:36:00 +00:00
Merge pull request #9 from Doctorado-ML/add_multiclass
#6 Add multiclass
@@ -3,9 +3,6 @@ coverage:
   project:
     default:
       target: 90%
-  patch:
-    default:
-      target: 90%
 comment:
   layout: "reach, diff, flags, files"
   behavior: default
main.py (11 changed lines)
@@ -75,14 +75,3 @@ print(f"Took {time.time() - now:.2f} seconds to train")
 print(clf)
 print(f"Classifier's accuracy (train): {clf.score(Xtrain, ytrain):.4f}")
 print(f"Classifier's accuracy (test) : {clf.score(Xtest, ytest):.4f}")
-proba = clf.predict_proba(Xtest)
-print(
-    "Checking that we have correct probabilities, these are probabilities of "
-    "sample belonging to class 1"
-)
-res0 = proba[proba[:, 0] == 0]
-res1 = proba[proba[:, 0] == 1]
-print("++++++++++res0 > .8++++++++++++")
-print(res0[res0[:, 1] > 0.8])
-print("**********res1 < .4************")
-print(res1[res1[:, 1] < 0.4])
stree/Strees.py (242 changed lines)
@@ -19,7 +19,6 @@ from sklearn.utils.validation import (
     check_is_fitted,
     _check_sample_weight,
 )
-from sklearn.utils.sparsefuncs import count_nonzero
 from sklearn.metrics._classification import _weighted_sum, _check_targets

@@ -131,6 +130,7 @@ class Stree(BaseEstimator, ClassifierMixin):
         tol: float = 1e-4,
         degree: int = 3,
         gamma="scale",
+        split_criteria="max_samples",
         min_samples_split: int = 0,
     ):
         self.max_iter = max_iter
@@ -142,17 +142,18 @@ class Stree(BaseEstimator, ClassifierMixin):
         self.gamma = gamma
         self.degree = degree
         self.min_samples_split = min_samples_split
+        self.split_criteria = split_criteria

     def _more_tags(self) -> dict:
-        """Required by sklearn to tell that this estimator is a binary classifier
+        """Required by sklearn to supply features of the classifier

         :return: the tag required
         :rtype: dict
         """
-        return {"binary_only": True, "requires_y": True}
+        return {"requires_y": True}

     def _split_array(self, origin: np.array, down: np.array) -> list:
-        """Split an array in two based on indices passed as down and its complement
+        """Split an array in two based on indices (down) and its complement

         :param origin: dataset to split
         :type origin: np.array
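Note: dropping the binary_only tag is the switch that tells scikit-learn's common estimator checks that y may now carry more than two classes; requires_y is still declared. A minimal sketch of how the tags surface, assuming sklearn's internal _get_tags helper (it merges _more_tags() over the defaults):

from stree import Stree

clf = Stree()
tags = clf._get_tags()      # sklearn merges _more_tags() into its default tag set
print(tags["requires_y"])   # True, still declared by _more_tags()
print(tags["binary_only"])  # False (the sklearn default) now that the override is gone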
@@ -163,8 +164,8 @@ class Stree(BaseEstimator, ClassifierMixin):
         """
         up = ~down
         return (
-            origin[up[:, 0]] if any(up) else None,
-            origin[down[:, 0]] if any(down) else None,
+            origin[up] if any(up) else None,
+            origin[down] if any(down) else None,
         )

     def _distances(self, node: Snode, data: np.ndarray) -> np.array:
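Note: the new indexing works because down is now a plain boolean mask of shape (m,), so origin[down] selects rows directly; the old up[:, 0] assumed a column vector. A self-contained illustration with synthetic data:

import numpy as np

origin = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
down = np.array([True, False, True, False])  # samples that go to the down branch
up = ~down
up_part = origin[up] if any(up) else None        # rows 1 and 3 -> [[3 4] [7 8]]
down_part = origin[down] if any(down) else None  # rows 0 and 2 -> [[1 2] [5 6]]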
@@ -178,27 +179,38 @@ class Stree(BaseEstimator, ClassifierMixin):
             the hyperplane of the node
         :rtype: np.array
         """
-        res = node._clf.decision_function(data)
-        if res.ndim == 1:
-            return np.expand_dims(res, 1)
-        elif res.shape[1] > 1:
-            # remove multiclass info
-            res = np.delete(res, slice(1, res.shape[1]), axis=1)
-        return res
+        return node._clf.decision_function(data)

-    def _split_criteria(self, data: np.array) -> np.array:
+    def _min_distance(self, data: np.array, _) -> np.array:
+        # chooses the lowest distance of every sample
+        indices = np.argmin(np.abs(data), axis=1)
+        return np.take(data, indices)
+
+    def _max_samples(self, data: np.array, y: np.array) -> np.array:
+        # select the class with max number of samples
+        _, samples = np.unique(y, return_counts=True)
+        selected = np.argmax(samples)
+        return data[:, selected]
+
+    def _split_criteria(self, data: np.array, node: Snode) -> np.array:
         """Set the criteria to split arrays

-        :param data: [description]
+        :param data: distances of samples to hyperplanes shape (m, nclasses)
+            if nclasses > 2 else (m,)
         :type data: np.array
-        :return: [description]
+        :param node: node containing the svm classifier
+        :type node: Snode
+        :return: array of booleans of samples under or above zero
         :rtype: np.array
         """
-        return (
-            data > 0
-            if data.shape[0] >= self.min_samples_split
-            else np.ones((data.shape[0], 1), dtype=bool)
-        )
+        if data.shape[0] < self.min_samples_split:
+            return np.ones((data.shape[0]), dtype=bool)
+        if data.ndim > 1:
+            # split criteria for multiclass
+            data = getattr(self, f"_{self.split_criteria}")(data, node._y)
+        res = data > 0
+        return res

     def fit(
         self, X: np.ndarray, y: np.ndarray, sample_weight: np.array = None
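Note: to make the two criteria concrete, here is a synthetic walk-through of what each one selects from a (m, nclasses) distance matrix. One caveat: np.take(data, indices) with no axis argument indexes the flattened array, so the explicit row-wise pick below is what the "lowest distance of every sample" comment describes:

import numpy as np

# distances of 4 samples to the hyperplanes of 3 one-vs-rest classifiers
data = np.array([[ 0.3, -1.2,  2.0],
                 [-0.1,  0.4, -0.7],
                 [ 1.5, -0.2,  0.9],
                 [-0.4,  0.8, -2.1]])
y = np.array([0, 2, 2, 1])  # labels of the samples that reached this node

# min_distance: per sample, keep the distance with the smallest magnitude
cols = np.argmin(np.abs(data), axis=1)           # [0, 0, 1, 0]
min_distance = data[np.arange(len(data)), cols]  # [ 0.3 -0.1 -0.2 -0.4]

# max_samples: keep the column of the most populated class in the node
_, counts = np.unique(y, return_counts=True)     # class 2 appears twice
max_samples = data[:, np.argmax(counts)]         # third column of data

# either criterion ends in the same boolean split
print(min_distance > 0)  # [ True False False False]
print(max_samples > 0)   # [ True False  True False]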
@@ -231,12 +243,19 @@ class Stree(BaseEstimator, ClassifierMixin):
                 f"Maximum depth has to be greater than 1... got (max_depth=\
                     {self.max_depth})"
             )
+        if self.split_criteria not in ["min_distance", "max_samples"]:
+            raise ValueError(
+                f"split_criteria has to be min_distance or \
+                    max_samples got ({self.split_criteria})"
+            )

-        check_classification_targets(y)
         X, y = check_X_y(X, y)
         sample_weight = _check_sample_weight(sample_weight, X)
+        check_classification_targets(y)
         # Initialize computed parameters
+        self.classes_, y = np.unique(y, return_inverse=True)
+        self.n_classes_ = self.classes_.shape[0]
         self.n_iter_ = self.max_iter
         self.depth_ = 0
         self.n_features_in_ = X.shape[1]
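Note: the classes_/n_classes_ bookkeeping follows the standard scikit-learn pattern: np.unique with return_inverse=True records the original labels and, in the same pass, re-encodes y as 0..n_classes-1. For example:

import numpy as np

y = np.array(["b", "a", "c", "a", "b"])
classes_, y_encoded = np.unique(y, return_inverse=True)
print(classes_)   # ['a' 'b' 'c'] -- stored so predict() can map back
print(y_encoded)  # [1 0 2 0 1]  -- the labels the tree actually trains on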
@@ -244,6 +263,52 @@ class Stree(BaseEstimator, ClassifierMixin):
         self._build_predictor()
         return self

+    def train(
+        self,
+        X: np.ndarray,
+        y: np.ndarray,
+        sample_weight: np.ndarray,
+        depth: int,
+        title: str,
+    ) -> Snode:
+        """Recursive function to split the original dataset into predictor
+        nodes (leaves)
+
+        :param X: samples dataset
+        :type X: np.ndarray
+        :param y: samples labels
+        :type y: np.ndarray
+        :param sample_weight: weight of samples. Rescale C per sample.
+            Hi weights force the classifier to put more emphasis on these
+            points.
+        :type sample_weight: np.ndarray
+        :param depth: actual depth in the tree
+        :type depth: int
+        :param title: description of the node
+        :type title: str
+        :return: binary tree
+        :rtype: Snode
+        """
+        if depth > self.__max_depth:
+            return None
+        if np.unique(y).shape[0] == 1:
+            # only 1 class => pure dataset
+            return Snode(None, X, y, title + ", <pure>")
+        # Train the model
+        clf = self._build_clf()
+        clf.fit(X, y, sample_weight=sample_weight)
+        node = Snode(clf, X, y, title)
+        self.depth_ = max(depth, self.depth_)
+        down = self._split_criteria(self._distances(node, X), node)
+        X_U, X_D = self._split_array(X, down)
+        y_u, y_d = self._split_array(y, down)
+        sw_u, sw_d = self._split_array(sample_weight, down)
+        if X_U is None or X_D is None:
+            # didn't part anything
+            return Snode(clf, X, y, title + ", <cgaf>")
+        node.set_up(self.train(X_U, y_u, sw_u, depth + 1, title + " - Up"))
+        node.set_down(self.train(X_D, y_d, sw_d, depth + 1, title + " - Down"))
+        return node
+
     def _build_predictor(self):
         """Process the leaves to make them predictors
         """
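Note: the docstring's "Rescale C per sample" is the libsvm/scikit-learn meaning of sample_weight: each sample's misclassification cost becomes C * w[i], so heavily weighted points dominate the node's hyperplane. A standalone sketch, assuming an SVC-style node classifier as clf.fit(X, y, sample_weight=...) suggests:

import numpy as np
from sklearn.svm import SVC

X = np.array([[0.0], [1.0], [1.5], [3.0]])
y = np.array([0, 0, 1, 1])
w = np.array([1.0, 10.0, 1.0, 1.0])  # emphasize the class-0 point at x=1.0

flat = SVC(kernel="linear").fit(X, y)
weighted = SVC(kernel="linear").fit(X, y, sample_weight=w)
# the weighted fit is more reluctant to leave the emphasized point near the
# wrong side of the margin, so the decision values differ
print(flat.decision_function([[1.25]]), weighted.decision_function([[1.25]]))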
@@ -278,52 +343,6 @@ class Stree(BaseEstimator, ClassifierMixin):
             )
         )

-    def train(
-        self,
-        X: np.ndarray,
-        y: np.ndarray,
-        sample_weight: np.ndarray,
-        depth: int,
-        title: str,
-    ) -> Snode:
-        """Recursive function to split the original dataset into predictor
-        nodes (leaves)
-
-        :param X: samples dataset
-        :type X: np.ndarray
-        :param y: samples labels
-        :type y: np.ndarray
-        :param sample_weight: weight of samples. Rescale C per sample.
-            Hi weights force the classifier to put more emphasis on these
-            points.
-        :type sample_weight: np.ndarray
-        :param depth: actual depth in the tree
-        :type depth: int
-        :param title: description of the node
-        :type title: str
-        :return: binary tree
-        :rtype: Snode
-        """
-        if depth > self.__max_depth:
-            return None
-        if np.unique(y).shape[0] == 1:
-            # only 1 class => pure dataset
-            return Snode(None, X, y, title + ", <pure>")
-        # Train the model
-        clf = self._build_clf()
-        clf.fit(X, y, sample_weight=sample_weight)
-        tree = Snode(clf, X, y, title)
-        self.depth_ = max(depth, self.depth_)
-        down = self._split_criteria(self._distances(tree, X))
-        X_U, X_D = self._split_array(X, down)
-        y_u, y_d = self._split_array(y, down)
-        sw_u, sw_d = self._split_array(sample_weight, down)
-        if X_U is None or X_D is None:
-            # didn't part anything
-            return Snode(clf, X, y, title + ", <cgaf>")
-        tree.set_up(self.train(X_U, y_u, sw_u, depth + 1, title + " - Up"))
-        tree.set_down(self.train(X_D, y_d, sw_d, depth + 1, title + " - Down"))
-        return tree
-
     def _reorder_results(self, y: np.array, indices: np.array) -> np.array:
         """Reorder an array based on the array of indices passed

@@ -334,12 +353,8 @@ class Stree(BaseEstimator, ClassifierMixin):
         :return: array y ordered
         :rtype: np.array
        """
-        if y.ndim > 1 and y.shape[1] > 1:
-            # if predict_proba return np.array of floats
-            y_ordered = np.zeros(y.shape, dtype=float)
-        else:
-            # return array of same type given in y
-            y_ordered = y.copy()
+        # return array of same type given in y
+        y_ordered = y.copy()
         indices = indices.astype(int)
         for i, index in enumerate(indices):
             y_ordered[index] = y[i]
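Note: predictions come back from the recursion in tree-traversal order together with the original row index of each one; _reorder_results is the scatter that undoes the shuffling. A tiny worked example:

import numpy as np

y = np.array([10, 20, 30, 40])    # results in tree-traversal order
indices = np.array([2, 0, 3, 1])  # original row of each result
y_ordered = y.copy()
for i, index in enumerate(indices.astype(int)):
    y_ordered[index] = y[i]
print(y_ordered)  # [20 40 10 30] -- row k now holds the result for sample k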
@@ -363,11 +378,11 @@ class Stree(BaseEstimator, ClassifierMixin):
                 # set a class for every sample in dataset
                 prediction = np.full((xp.shape[0], 1), node._class)
                 return prediction, indices
-            down = self._split_criteria(self._distances(node, xp))
-            X_U, X_D = self._split_array(xp, down)
+            down = self._split_criteria(self._distances(node, xp), node)
+            x_u, x_d = self._split_array(xp, down)
             i_u, i_d = self._split_array(indices, down)
-            prx_u, prin_u = predict_class(X_U, i_u, node.get_up())
-            prx_d, prin_d = predict_class(X_D, i_d, node.get_down())
+            prx_u, prin_u = predict_class(x_u, i_u, node.get_up())
+            prx_d, prin_d = predict_class(x_d, i_d, node.get_down())
             return np.append(prx_u, prx_d), np.append(prin_u, prin_d)

         # sklearn check
@@ -383,68 +398,6 @@ class Stree(BaseEstimator, ClassifierMixin):
         )
         return self.classes_[result]

-    def predict_proba(self, X: np.array) -> np.array:
-        """Computes an approximation of the probability of samples belonging to
-        class 0 and 1
-        :param X: dataset
-        :type X: np.array
-        :return: array array of shape (m, num_classes), probability of being
-        each class
-        :rtype: np.array
-        """
-
-        def predict_class(
-            xp: np.array, indices: np.array, dist: np.array, node: Snode
-        ) -> np.array:
-            """Run the tree to compute predictions
-
-            :param xp: subdataset of samples
-            :type xp: np.array
-            :param indices: indices of subdataset samples to rebuild original
-            order
-            :type indices: np.array
-            :param dist: distances of every sample to the hyperplane or the
-            father node
-            :type dist: np.array
-            :param node: node of the leaf with the class
-            :type node: Snode
-            :return: array of labels and distances, array of indices
-            :rtype: np.array
-            """
-            if xp is None:
-                return [], []
-            if node.is_leaf():
-                # set a class for every sample in dataset
-                prediction = np.full((xp.shape[0], 1), node._class)
-                prediction_proba = dist
-                return np.append(prediction, prediction_proba, axis=1), indices
-            distances = self._distances(node, xp)
-            down = self._split_criteria(distances)
-            X_U, X_D = self._split_array(xp, down)
-            i_u, i_d = self._split_array(indices, down)
-            di_u, di_d = self._split_array(distances, down)
-            prx_u, prin_u = predict_class(X_U, i_u, di_u, node.get_up())
-            prx_d, prin_d = predict_class(X_D, i_d, di_d, node.get_down())
-            return np.append(prx_u, prx_d), np.append(prin_u, prin_d)
-
-        # sklearn check
-        check_is_fitted(self, ["tree_"])
-        # Input validation
-        X = check_array(X)
-        # setup prediction & make it happen
-        indices = np.arange(X.shape[0])
-        empty_dist = np.empty((X.shape[0], 1), dtype=float)
-        result, indices = predict_class(X, indices, empty_dist, self.tree_)
-        result = result.reshape(X.shape[0], 2)
-        # Turn distances to hyperplane into probabilities based on fitting
-        # distances of samples to its hyperplane that classified them, to the
-        # sigmoid function
-        # Probability of being 1
-        result[:, 1] = 1 / (1 + np.exp(-result[:, 1]))
-        # Probability of being 0
-        result[:, 0] = 1 - result[:, 1]
-        return self._reorder_results(result, indices)
-
     def score(
         self, X: np.array, y: np.array, sample_weight: np.array = None
     ) -> float:
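Note: for reference, the heart of the removed method was a logistic squashing of signed hyperplane distances into a two-column estimate, which only makes sense for two classes and so had to go with this PR. Its essential computation, in isolation:

import numpy as np

dist = np.array([-2.0, 0.0, 3.0])      # signed distances kept at the leaves
proba = np.empty((dist.shape[0], 2))
proba[:, 1] = 1 / (1 + np.exp(-dist))  # sigmoid -> P(class 1)
proba[:, 0] = 1 - proba[:, 1]          # complement -> P(class 0)
print(proba)  # rows sum to 1; exactly 0.5/0.5 on the hyperplane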
@@ -468,17 +421,12 @@ class Stree(BaseEstimator, ClassifierMixin):
         # Compute accuracy for each possible representation
         y_type, y_true, y_pred = _check_targets(y, y_pred)
         check_consistent_length(y_true, y_pred, sample_weight)
-        if y_type.startswith("multilabel"):
-            differing_labels = count_nonzero(y_true - y_pred, axis=1)
-            score = differing_labels == 0
-        else:
-            score = y_true == y_pred
-
+        score = y_true == y_pred
         return _weighted_sum(score, sample_weight, normalize=True)

     def __iter__(self) -> Siterator:
-        """Create an iterator to be able to visit the nodes of the tree in preorder,
-        can make a list with all the nodes in preorder
+        """Create an iterator to be able to visit the nodes of the tree in
+        preorder, can make a list with all the nodes in preorder

         :return: an iterator, can for i in... and list(...)
         :rtype: Siterator
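Note: with the multilabel branch gone (and the count_nonzero import it needed removed at the top of the file), score reduces to a weighted mean of per-sample hits; with normalize=True, sklearn's private _weighted_sum gives the same value as np.average:

import numpy as np

y_true = np.array([0, 1, 2, 1])
y_pred = np.array([0, 1, 1, 1])
weights = np.array([1.0, 1.0, 2.0, 1.0])

score = y_true == y_pred                   # per-sample hits: [1, 1, 0, 1]
print(np.average(score, weights=weights))  # 0.6, the weighted accuracy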
@@ -68,6 +68,11 @@ class Stree_grapher_test(unittest.TestCase):
         self.assertEqual(accuracy_score, accuracy_computed)
         self.assertGreater(accuracy_score, 0.86)

+    def test_score_4dims(self):
+        X, y = get_dataset(self._random_state, n_features=4)
+        accuracy_score = self._clf.score(X, y)
+        self.assertEqual(accuracy_score, 0.95)
+
     def test_save_all(self):
         folder_name = os.path.join(os.sep, "tmp", "stree")
         if os.path.isdir(folder_name):
@@ -171,11 +176,13 @@ class Snode_graph_test(unittest.TestCase):

     def test_plot_hyperplane_with_distribution(self):
         plt.close()
+        # select a pure node
+        node = self._clf._tree_gr.get_down().get_up().get_up()
         with warnings.catch_warnings():
             warnings.simplefilter("ignore")
             matplotlib.use("Agg")
             num_figures_before = plt.gcf().number
-            self._clf._tree_gr.plot_hyperplane(plot_distribution=True)
+            node.plot_hyperplane(plot_distribution=True)
             num_figures_after = plt.gcf().number
             self.assertEqual(1, num_figures_after - num_figures_before)
@@ -209,3 +216,11 @@ class Snode_graph_test(unittest.TestCase):
         self.assertEqual(x, xx)
         self.assertEqual(y, yy)
         self.assertEqual(z, zz)
+
+    def test_cmap_change(self):
+        node = Snode_graph(Snode(None, None, None, "test"))
+        self.assertEqual("jet", node._get_cmap())
+        # make node pure
+        node._belief = 1.0
+        node._class = 1
+        self.assertEqual("jet_r", node._get_cmap())
@@ -2,23 +2,22 @@ import os
 import unittest

 import numpy as np
-from sklearn.datasets import make_classification
+from sklearn.datasets import make_classification, load_iris

 from stree import Stree, Snode


-def get_dataset(random_state=0):
+def get_dataset(random_state=0, n_classes=2):
     X, y = make_classification(
         n_samples=1500,
         n_features=3,
         n_informative=3,
         n_redundant=0,
         n_repeated=0,
-        n_classes=2,
+        n_classes=n_classes,
         n_clusters_per_class=2,
         class_sep=1.5,
         flip_y=0,
-        weights=[0.5, 0.5],
         random_state=random_state,
     )
     return X, y
@@ -104,9 +103,8 @@ class Stree_test(unittest.TestCase):
         return res

     def test_single_prediction(self):
-        probs = [0.29026400766, 0.73105613, 0.0307635]
         X, y = get_dataset(self._random_state)
-        for kernel, prob in zip(self._kernels, probs):
+        for kernel in self._kernels:
             clf = Stree(kernel=kernel, random_state=self._random_state)
             yp = clf.fit(X, y).predict((X[0, :].reshape(-1, X.shape[1])))
             self.assertEqual(yp[0], y[0])
@@ -122,10 +120,12 @@ class Stree_test(unittest.TestCase):

     def test_score(self):
         X, y = get_dataset(self._random_state)
-        for kernel, accuracy_expected in zip(
-            self._kernels,
-            [0.9506666666666667, 0.9606666666666667, 0.9433333333333334],
-        ):
+        accuracies = [
+            0.9506666666666667,
+            0.9606666666666667,
+            0.9433333333333334,
+        ]
+        for kernel, accuracy_expected in zip(self._kernels, accuracies):
             clf = Stree(random_state=self._random_state, kernel=kernel,)
             clf.fit(X, y)
             accuracy_score = clf.score(X, y)
@@ -134,38 +134,6 @@ class Stree_test(unittest.TestCase):
         self.assertEqual(accuracy_score, accuracy_computed)
         self.assertAlmostEqual(accuracy_expected, accuracy_score)

-    def test_single_predict_proba(self):
-        """Check the element 28 probability of being 1
-        """
-        decimals = 5
-        element = 28
-        probs = [0.29026400766, 0.73105613, 0.0307635]
-        X, y = get_dataset(self._random_state)
-        self.assertEqual(1, y[element])
-        for kernel, prob in zip(self._kernels, probs):
-            clf = Stree(kernel=kernel, random_state=self._random_state)
-            yp = clf.fit(X, y).predict_proba(
-                X[element, :].reshape(-1, X.shape[1])
-            )
-            self.assertAlmostEqual(
-                np.round(1 - prob, decimals), np.round(yp[0:, 0], decimals)
-            )
-            self.assertAlmostEqual(
-                round(prob, decimals), round(yp[0, 1], decimals), decimals
-            )
-
-    def test_multiple_predict_proba(self):
-        # First 27 elements the predictions are the same as the truth
-        num = 27
-        X, y = get_dataset(self._random_state)
-        for kernel in self._kernels:
-            clf = Stree(kernel=kernel, random_state=self._random_state)
-            clf.fit(X, y)
-            yp = clf.predict_proba(X[:num, :])
-            self.assertListEqual(
-                y[:num].tolist(), np.argmax(yp[:num], axis=1).tolist()
-            )
-
     def test_single_vs_multiple_prediction(self):
         """Check if predicting sample by sample gives the same result as
         predicting all samples at once
@@ -225,6 +193,11 @@ class Stree_test(unittest.TestCase):
         with self.assertRaises(ValueError):
             tclf.fit(*get_dataset(self._random_state))

+    def test_exception_if_bogus_split_criteria(self):
+        tclf = Stree(split_criteria="duck")
+        with self.assertRaises(ValueError):
+            tclf.fit(*get_dataset(self._random_state))
+
     def test_check_max_depth_is_positive_or_None(self):
         tcl = Stree()
         self.assertIsNone(tcl.max_depth)
@@ -256,14 +229,56 @@ class Stree_test(unittest.TestCase):
         self.assertIsNone(tcl_nosplit.tree_.get_down())
         self.assertIsNone(tcl_nosplit.tree_.get_up())

-    def test_muticlass_dataset(self):
+    def test_simple_muticlass_dataset(self):
         for kernel in self._kernels:
-            clf = Stree(kernel=kernel, random_state=self._random_state)
-            px = [[1, 2], [3, 4], [5, 6]]
-            py = [1, 2, 3]
+            clf = Stree(
+                kernel=kernel,
+                split_criteria="max_samples",
+                random_state=self._random_state,
+            )
+            px = [[1, 2], [5, 6], [9, 10]]
+            py = [0, 1, 2]
             clf.fit(px, py)
             self.assertEqual(1.0, clf.score(px, py))
-            self.assertListEqual([1, 2, 3], clf.predict(px).tolist())
+            self.assertListEqual(py, clf.predict(px).tolist())
             self.assertListEqual(py, clf.classes_.tolist())
+
+    def test_muticlass_dataset(self):
+        datasets = {
+            "Synt": get_dataset(random_state=self._random_state, n_classes=3),
+            "Iris": load_iris(return_X_y=True),
+        }
+        outcomes = {
+            "Synt": {
+                "max_samples linear": 0.9533333333333334,
+                "max_samples rbf": 0.836,
+                "max_samples poly": 0.9473333333333334,
+                "min_distance linear": 0.9533333333333334,
+                "min_distance rbf": 0.836,
+                "min_distance poly": 0.9473333333333334,
+            },
+            "Iris": {
+                "max_samples linear": 0.98,
+                "max_samples rbf": 1.0,
+                "max_samples poly": 1.0,
+                "min_distance linear": 0.98,
+                "min_distance rbf": 1.0,
+                "min_distance poly": 1.0,
+            },
+        }
+        for name, dataset in datasets.items():
+            px, py = dataset
+            for criteria in ["max_samples", "min_distance"]:
+                for kernel in self._kernels:
+                    clf = Stree(
+                        C=1e4,
+                        max_iter=1e4,
+                        kernel=kernel,
+                        random_state=self._random_state,
+                    )
+                    clf.fit(px, py)
+                    outcome = outcomes[name][f"{criteria} {kernel}"]
+                    self.assertAlmostEqual(outcome, clf.score(px, py))


 class Snode_test(unittest.TestCase):