Compare commits

...

2 Commits

Author SHA1 Message Date
Ricardo Montañana Gómez
c37f044e3a Update doc and version 1.30 (#55)
* Add complete classes counts to node and tests

* Implement optimized predict and new predict_proba

* Add predict_proba test

* Add python 3.10 to CI

* Update version number and documentation
2022-10-21 13:31:59 +02:00
Ricardo Montañana Gómez
2f6ae648a1 New predict proba (#53)
* Add complete classes counts to node and tests

* Implement optimized predict and new predict_proba

* Add predict_proba test

* Add python 3.10 to CI
2022-10-21 12:26:46 +02:00
10 changed files with 172 additions and 93 deletions

View File

@@ -13,7 +13,7 @@ jobs:
strategy: strategy:
matrix: matrix:
os: [macos-latest, ubuntu-latest, windows-latest] os: [macos-latest, ubuntu-latest, windows-latest]
python: [3.8] python: [3.8, "3.10"]
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2

View File

@@ -50,7 +50,8 @@ Can be found in [stree.readthedocs.io](https://stree.readthedocs.io/en/stable/)
| | criterion | {“gini”, “entropy”} | entropy | The function to measure the quality of a split (only used if max_features != num_features). <br>Supported criteria are “gini” for the Gini impurity and “entropy” for the information gain. | | | criterion | {“gini”, “entropy”} | entropy | The function to measure the quality of a split (only used if max_features != num_features). <br>Supported criteria are “gini” for the Gini impurity and “entropy” for the information gain. |
| | min_samples_split | \<int\> | 0 | The minimum number of samples required to split an internal node. 0 (default) for any | | | min_samples_split | \<int\> | 0 | The minimum number of samples required to split an internal node. 0 (default) for any |
| | max_features | \<int\>, \<float\> <br><br>or {“auto”, “sqrt”, “log2”} | None | The number of features to consider when looking for the split:<br>If int, then consider max_features features at each split.<br>If float, then max_features is a fraction and int(max_features \* n_features) features are considered at each split.<br>If “auto”, then max_features=sqrt(n_features).<br>If “sqrt”, then max_features=sqrt(n_features).<br>If “log2”, then max_features=log2(n_features).<br>If None, then max_features=n_features. | | | max_features | \<int\>, \<float\> <br><br>or {“auto”, “sqrt”, “log2”} | None | The number of features to consider when looking for the split:<br>If int, then consider max_features features at each split.<br>If float, then max_features is a fraction and int(max_features \* n_features) features are considered at each split.<br>If “auto”, then max_features=sqrt(n_features).<br>If “sqrt”, then max_features=sqrt(n_features).<br>If “log2”, then max_features=log2(n_features).<br>If None, then max_features=n_features. |
| | splitter | {"best", "random", "trandom", "mutual", "cfs", "fcbf", "iwss"} | "random" | The strategy used to choose the feature set at each node (only used if max_features < num_features). Supported strategies are: **best”**: sklearn SelectKBest algorithm is used in every node to choose the max_features best features. **random”**: The algorithm generates 5 candidates and choose the best (max. info. gain) of them. **trandom”**: The algorithm generates only one random combination. **"mutual"**: Chooses the best features w.r.t. their mutual info with the label. **"cfs"**: Apply Correlation-based Feature Selection. **"fcbf"**: Apply Fast Correlation-Based Filter. **"iwss"**: IWSS based algorithm | | | splitter | {"best", "random", "trandom", "mutual", "cfs", "fcbf", "iwss"} | "random" | The strategy used to choose the feature set at each node (only used if max_features < num_features).
Supported strategies are: **best”**: sklearn SelectKBest algorithm is used in every node to choose the max_features best features. **random”**: The algorithm generates 5 candidates and choose the best (max. info. gain) of them. **trandom”**: The algorithm generates only one random combination. **"mutual"**: Chooses the best features w.r.t. their mutual info with the label. **"cfs"**: Apply Correlation-based Feature Selection. **"fcbf"**: Apply Fast Correlation-Based Filter. **"iwss"**: IWSS based algorithm |
| | normalize | \<bool\> | False | If standardization of features should be applied on each node with the samples that reach it | | | normalize | \<bool\> | False | If standardization of features should be applied on each node with the samples that reach it |
| \* | multiclass_strategy | {"ovo", "ovr"} | "ovo" | Strategy to use with multiclass datasets, **"ovo"**: one versus one. **"ovr"**: one versus rest | | \* | multiclass_strategy | {"ovo", "ovr"} | "ovo" | Strategy to use with multiclass datasets, **"ovo"**: one versus one. **"ovr"**: one versus rest |

View File

@@ -12,19 +12,18 @@
# #
import os import os
import sys import sys
import stree from stree._version import __version__
sys.path.insert(0, os.path.abspath("../../stree/")) sys.path.insert(0, os.path.abspath("../../stree/"))
# -- Project information ----------------------------------------------------- # -- Project information -----------------------------------------------------
project = "STree" project = "STree"
copyright = "2020 - 2021, Ricardo Montañana Gómez" copyright = "2020 - 2022, Ricardo Montañana Gómez"
author = "Ricardo Montañana Gómez" author = "Ricardo Montañana Gómez"
# The full version, including alpha/beta/rc tags # The full version, including alpha/beta/rc tags
version = stree.__version__ version = __version__
release = version release = version

View File

@@ -3,20 +3,20 @@
| | **Hyperparameter** | **Type/Values** | **Default** | **Meaning** | | | **Hyperparameter** | **Type/Values** | **Default** | **Meaning** |
| --- | ------------------- | -------------------------------------------------------------- | ----------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | --- | ------------------- | -------------------------------------------------------------- | ----------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| \* | C | \<float\> | 1.0 | Regularization parameter. The strength of the regularization is inversely proportional to C. Must be strictly positive. | | \* | C | \<float\> | 1.0 | Regularization parameter. The strength of the regularization is inversely proportional to C. Must be strictly positive. |
| \* | kernel | {"liblinear", "linear", "poly", "rbf", "sigmoid"} | linear | Specifies the kernel type to be used in the algorithm. It must be one of liblinear, linear, poly or rbf. liblinear uses [liblinear](https://www.csie.ntu.edu.tw/~cjlin/liblinear/) library and the rest uses [libsvm](https://www.csie.ntu.edu.tw/~cjlin/libsvm/) library through scikit-learn library | | \* | kernel | {"liblinear", "linear", "poly", "rbf", "sigmoid"} | linear | Specifies the kernel type to be used in the algorithm. It must be one of liblinear, linear, poly or rbf.<br>liblinear uses [liblinear](https://www.csie.ntu.edu.tw/~cjlin/liblinear/) library and the rest uses [libsvm](https://www.csie.ntu.edu.tw/~cjlin/libsvm/) library through scikit-learn library |
| \* | max_iter | \<int\> | 1e5 | Hard limit on iterations within solver, or -1 for no limit. | | \* | max_iter | \<int\> | 1e5 | Hard limit on iterations within solver, or -1 for no limit. |
| \* | random_state | \<int\> | None | Controls the pseudo random number generation for shuffling the data for probability estimates. Ignored when probability is False.<br>Pass an int for reproducible output across multiple function calls | | \* | random_state | \<int\> | None | Controls the pseudo random number generation for shuffling the data for probability estimates. Ignored when probability is False.<br>Pass an int for reproducible output across multiple function calls |
| | max_depth | \<int\> | None | Specifies the maximum depth of the tree | | | max_depth | \<int\> | None | Specifies the maximum depth of the tree |
| \* | tol | \<float\> | 1e-4 | Tolerance for stopping criterion. | | \* | tol | \<float\> | 1e-4 | Tolerance for stopping criterion. |
| \* | degree | \<int\> | 3 | Degree of the polynomial kernel function (poly). Ignored by all other kernels. | | \* | degree | \<int\> | 3 | Degree of the polynomial kernel function (poly). Ignored by all other kernels. |
| \* | gamma | {"scale", "auto"} or \<float\> | scale | Kernel coefficient for rbf, poly and sigmoid.<br>if gamma='scale' (default) is passed then it uses 1 / (n_features \* X.var()) as value of gamma,<br>if auto, uses 1 / n_features. | | \* | gamma | {"scale", "auto"} or \<float\> | scale | Kernel coefficient for rbf, poly and sigmoid.<br>if gamma='scale' (default) is passed then it uses 1 / (n_features \* X.var()) as value of gamma,<br>if auto, uses 1 / n_features. |
| | split_criteria | {"impurity", "max_samples"} | impurity | Decides (just in case of a multi class classification) which column (class) use to split the dataset in a node\*\*. max_samples is incompatible with 'ovo' multiclass_strategy | | | split_criteria | {"impurity", "max_samples"} | impurity | Decides (just in case of a multi class classification) which column (class) use to split the dataset in a node\*\*.<br>max_samples is incompatible with 'ovo' multiclass_strategy |
| | criterion | {“gini”, “entropy”} | entropy | The function to measure the quality of a split (only used if max_features != num_features). <br>Supported criteria are “gini” for the Gini impurity and “entropy” for the information gain. | | | criterion | {“gini”, “entropy”} | entropy | The function to measure the quality of a split (only used if max_features != num_features).<br>Supported criteria are “gini” for the Gini impurity and “entropy” for the information gain. |
| | min_samples_split | \<int\> | 0 | The minimum number of samples required to split an internal node. 0 (default) for any | | | min_samples_split | \<int\> | 0 | The minimum number of samples required to split an internal node. 0 (default) for any |
| | max_features | \<int\>, \<float\> <br><br>or {“auto”, “sqrt”, “log2”} | None | The number of features to consider when looking for the split:<br>If int, then consider max_features features at each split.<br>If float, then max_features is a fraction and int(max_features \* n_features) features are considered at each split.<br>If “auto”, then max_features=sqrt(n_features).<br>If “sqrt”, then max_features=sqrt(n_features).<br>If “log2”, then max_features=log2(n_features).<br>If None, then max_features=n_features. | | | max_features | \<int\>, \<float\> <br><br>or {“auto”, “sqrt”, “log2”} | None | The number of features to consider when looking for the split:<br>If int, then consider max_features features at each split.<br>If float, then max_features is a fraction and int(max_features \* n_features) features are considered at each split.<br>If “auto”, then max_features=sqrt(n_features).<br>If “sqrt”, then max_features=sqrt(n_features).<br>If “log2”, then max_features=log2(n_features).<br>If None, then max_features=n_features. |
| | splitter | {"best", "random", "trandom", "mutual", "cfs", "fcbf", "iwss"} | "random" | The strategy used to choose the feature set at each node (only used if max_features < num_features). Supported strategies are: **best”**: sklearn SelectKBest algorithm is used in every node to choose the max_features best features. **random”**: The algorithm generates 5 candidates and choose the best (max. info. gain) of them. **trandom”**: The algorithm generates only one random combination. **"mutual"**: Chooses the best features w.r.t. their mutual info with the label. **"cfs"**: Apply Correlation-based Feature Selection. **"fcbf"**: Apply Fast Correlation-Based Filter. **"iwss"**: IWSS based algorithm | | | splitter | {"best", "random", "trandom", "mutual", "cfs", "fcbf", "iwss"} | "random" | The strategy used to choose the feature set at each node (only used if max_features < num_features).<br>Supported strategies are:<br>**“best”**: sklearn SelectKBest algorithm is used in every node to choose the max_features best features.<br>**“random”**: The algorithm generates 5 candidates and choose the best (max. info. gain) of them.<br>**“trandom”**: The algorithm generates only one random combination.<br>**"mutual"**: Chooses the best features w.r.t. their mutual info with the label.<br>**"cfs"**: Apply Correlation-based Feature Selection.<br>**"fcbf"**: Apply Fast Correlation-Based Filter.<br>**"iwss"**: IWSS based algorithm |
| | normalize | \<bool\> | False | If standardization of features should be applied on each node with the samples that reach it | | | normalize | \<bool\> | False | If standardization of features should be applied on each node with the samples that reach it |
| \* | multiclass_strategy | {"ovo", "ovr"} | "ovo" | Strategy to use with multiclass datasets, **"ovo"**: one versus one. **"ovr"**: one versus rest | | \* | multiclass_strategy | {"ovo", "ovr"} | "ovo" | Strategy to use with multiclass datasets:<br>**"ovo"**: one versus one.<br>**"ovr"**: one versus rest |
\* Hyperparameter used by the support vector classifier of every node \* Hyperparameter used by the support vector classifier of every node

View File

@@ -7,9 +7,8 @@ def readme():
return f.read() return f.read()
def get_data(field): def get_data(field, file_name="__init__.py"):
item = "" item = ""
file_name = "_version.py" if field == "version" else "__init__.py"
with open(os.path.join("stree", file_name)) as f: with open(os.path.join("stree", file_name)) as f:
for line in f.readlines(): for line in f.readlines():
if line.startswith(f"__{field}__"): if line.startswith(f"__{field}__"):
@@ -21,9 +20,14 @@ def get_data(field):
return item return item
def get_requirements():
with open("requirements.txt") as f:
return f.read().splitlines()
setuptools.setup( setuptools.setup(
name="STree", name="STree",
version=get_data("version"), version=get_data("version", "_version.py"),
license=get_data("license"), license=get_data("license"),
description="Oblique decision tree with svm nodes", description="Oblique decision tree with svm nodes",
long_description=readme(), long_description=readme(),
@@ -46,7 +50,7 @@ setuptools.setup(
"Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Artificial Intelligence",
"Intended Audience :: Science/Research", "Intended Audience :: Science/Research",
], ],
install_requires=["scikit-learn", "mufs"], install_requires=get_requirements(),
test_suite="stree.tests", test_suite="stree.tests",
zip_safe=False, zip_safe=False,
) )

View File

@@ -68,6 +68,7 @@ class Snode:
self._impurity = impurity self._impurity = impurity
self._partition_column: int = -1 self._partition_column: int = -1
self._scaler = scaler self._scaler = scaler
self._proba = None
@classmethod @classmethod
def copy(cls, node: "Snode") -> "Snode": def copy(cls, node: "Snode") -> "Snode":
@@ -127,23 +128,22 @@ class Snode:
def get_up(self) -> "Snode": def get_up(self) -> "Snode":
return self._up return self._up
def make_predictor(self): def make_predictor(self, num_classes: int) -> None:
"""Compute the class of the predictor and its belief based on the """Compute the class of the predictor and its belief based on the
subdataset of the node only if it is a leaf subdataset of the node only if it is a leaf
""" """
if not self.is_leaf(): if not self.is_leaf():
return return
classes, card = np.unique(self._y, return_counts=True) classes, card = np.unique(self._y, return_counts=True)
if len(classes) > 1: self._proba = np.zeros((num_classes,), dtype=np.int64)
for c, n in zip(classes, card):
self._proba[c] = n
try:
max_card = max(card) max_card = max(card)
self._class = classes[card == max_card][0] self._class = classes[card == max_card][0]
self._belief = max_card / np.sum(card) self._belief = max_card / np.sum(card)
else: except ValueError:
self._belief = 1 self._class = None
try:
self._class = classes[0]
except IndexError:
self._class = None
def graph(self): def graph(self):
""" """
@@ -155,7 +155,7 @@ class Snode:
output += ( output += (
f'N{id(self)} [shape=box style=filled label="' f'N{id(self)} [shape=box style=filled label="'
f"class={self._class} impurity={self._impurity:.3f} " f"class={self._class} impurity={self._impurity:.3f} "
f'classes={count_values[0]} samples={count_values[1]}"];\n' f'counts={self._proba}"];\n'
) )
else: else:
output += ( output += (

View File

@@ -314,7 +314,7 @@ class Stree(BaseEstimator, ClassifierMixin):
if np.unique(y).shape[0] == 1: if np.unique(y).shape[0] == 1:
# only 1 class => pure dataset # only 1 class => pure dataset
node.set_title(title + ", <pure>") node.set_title(title + ", <pure>")
node.make_predictor() node.make_predictor(self.n_classes_)
return node return node
# Train the model # Train the model
clf = self._build_clf() clf = self._build_clf()
@@ -333,7 +333,7 @@ class Stree(BaseEstimator, ClassifierMixin):
if X_U is None or X_D is None: if X_U is None or X_D is None:
# didn't part anything # didn't part anything
node.set_title(title + ", <cgaf>") node.set_title(title + ", <cgaf>")
node.make_predictor() node.make_predictor(self.n_classes_)
return node return node
node.set_up( node.set_up(
self._train(X_U, y_u, sw_u, depth + 1, title + f" - Up({depth+1})") self._train(X_U, y_u, sw_u, depth + 1, title + f" - Up({depth+1})")
@@ -367,28 +367,100 @@ class Stree(BaseEstimator, ClassifierMixin):
) )
) )
@staticmethod def __predict_class(self, X: np.array) -> np.array:
def _reorder_results(y: np.array, indices: np.array) -> np.array: """Compute the predicted class for the samples in X. Returns the number
"""Reorder an array based on the array of indices passed of samples of each class in the corresponding leaf node.
Parameters Parameters
---------- ----------
y : np.array X : np.array
data untidy Array of samples
indices : np.array
indices used to set order
Returns Returns
------- -------
np.array np.array
array y ordered Array of shape (n_samples, n_classes) with the number of samples
of each class in the corresponding leaf node
""" """
# return array of same type given in y
y_ordered = y.copy() def compute_prediction(xp, indices, node):
indices = indices.astype(int) if xp is None:
for i, index in enumerate(indices): return
y_ordered[index] = y[i] if node.is_leaf():
return y_ordered # set a class for indices
result[indices] = node._proba
return
self.splitter_.partition(xp, node, train=False)
x_u, x_d = self.splitter_.part(xp)
i_u, i_d = self.splitter_.part(indices)
compute_prediction(x_u, i_u, node.get_up())
compute_prediction(x_d, i_d, node.get_down())
# setup prediction & make it happen
result = np.zeros((X.shape[0], self.n_classes_))
indices = np.arange(X.shape[0])
compute_prediction(X, indices, self.tree_)
return result
def check_predict(self, X) -> np.array:
"""Checks predict and predict_proba preconditions. If input X is not an
np.array convert it to one.
Parameters
----------
X : np.ndarray
Array of samples
Returns
-------
np.array
Array of samples
Raises
------
ValueError
If number of features of X is different of the number of features
in training data
"""
check_is_fitted(self, ["tree_"])
# Input validation
X = check_array(X)
if X.shape[1] != self.n_features_:
raise ValueError(
f"Expected {self.n_features_} features but got "
f"({X.shape[1]})"
)
return X
def predict_proba(self, X: np.array) -> np.array:
"""Predict class probabilities of the input samples X.
The predicted class probability is the fraction of samples of the same
class in a leaf.
Parameters
----------
X : dataset of samples.
Returns
-------
proba : array of shape (n_samples, n_classes)
The class probabilities of the input samples.
Raises
------
ValueError
if dataset with inconsistent number of features
NotFittedError
if model is not fitted
"""
X = self.check_predict(X)
# return # of samples of each class in leaf node
values = self.__predict_class(X)
normalizer = values.sum(axis=1)[:, np.newaxis]
normalizer[normalizer == 0.0] = 1.0
return values / normalizer
def predict(self, X: np.array) -> np.array: def predict(self, X: np.array) -> np.array:
"""Predict labels for each sample in dataset passed """Predict labels for each sample in dataset passed
@@ -410,40 +482,8 @@ class Stree(BaseEstimator, ClassifierMixin):
NotFittedError NotFittedError
if model is not fitted if model is not fitted
""" """
X = self.check_predict(X)
def predict_class( return self.classes_[np.argmax(self.__predict_class(X), axis=1)]
xp: np.array, indices: np.array, node: Snode
) -> np.array:
if xp is None:
return [], []
if node.is_leaf():
# set a class for every sample in dataset
prediction = np.full((xp.shape[0], 1), node._class)
return prediction, indices
self.splitter_.partition(xp, node, train=False)
x_u, x_d = self.splitter_.part(xp)
i_u, i_d = self.splitter_.part(indices)
prx_u, prin_u = predict_class(x_u, i_u, node.get_up())
prx_d, prin_d = predict_class(x_d, i_d, node.get_down())
return np.append(prx_u, prx_d), np.append(prin_u, prin_d)
# sklearn check
check_is_fitted(self, ["tree_"])
# Input validation
X = check_array(X)
if X.shape[1] != self.n_features_:
raise ValueError(
f"Expected {self.n_features_} features but got "
f"({X.shape[1]})"
)
# setup prediction & make it happen
indices = np.arange(X.shape[0])
result = (
self._reorder_results(*predict_class(X, indices, self.tree_))
.astype(int)
.ravel()
)
return self.classes_[result]
def nodes_leaves(self) -> tuple: def nodes_leaves(self) -> tuple:
"""Compute the number of nodes and leaves in the built tree """Compute the number of nodes and leaves in the built tree

View File

@@ -1 +1 @@
__version__ = "1.2.4" __version__ = "1.3.0"

View File

@@ -67,10 +67,28 @@ class Snode_test(unittest.TestCase):
def test_make_predictor_on_leaf(self): def test_make_predictor_on_leaf(self):
test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test") test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test")
test.make_predictor() test.make_predictor(2)
self.assertEqual(1, test._class) self.assertEqual(1, test._class)
self.assertEqual(0.75, test._belief) self.assertEqual(0.75, test._belief)
self.assertEqual(-1, test._partition_column) self.assertEqual(-1, test._partition_column)
self.assertListEqual([1, 3], test._proba.tolist())
def test_make_predictor_on_not_leaf(self):
test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test")
test.set_up(Snode(None, [1], [1], [], 0.0, "another_test"))
test.make_predictor(2)
self.assertIsNone(test._class)
self.assertEqual(0, test._belief)
self.assertEqual(-1, test._partition_column)
self.assertEqual(-1, test.get_up()._partition_column)
self.assertIsNone(test._proba)
def test_make_predictor_on_leaf_bogus_data(self):
test = Snode(None, [1, 2, 3, 4], [], [], 0.0, "test")
test.make_predictor(2)
self.assertIsNone(test._class)
self.assertEqual(-1, test._partition_column)
self.assertListEqual([0, 0], test._proba.tolist())
def test_set_title(self): def test_set_title(self):
test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test") test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test")
@@ -97,21 +115,6 @@ class Snode_test(unittest.TestCase):
test.set_features([1, 2]) test.set_features([1, 2])
self.assertListEqual([1, 2], test.get_features()) self.assertListEqual([1, 2], test.get_features())
def test_make_predictor_on_not_leaf(self):
test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test")
test.set_up(Snode(None, [1], [1], [], 0.0, "another_test"))
test.make_predictor()
self.assertIsNone(test._class)
self.assertEqual(0, test._belief)
self.assertEqual(-1, test._partition_column)
self.assertEqual(-1, test.get_up()._partition_column)
def test_make_predictor_on_leaf_bogus_data(self):
test = Snode(None, [1, 2, 3, 4], [], [], 0.0, "test")
test.make_predictor()
self.assertIsNone(test._class)
self.assertEqual(-1, test._partition_column)
def test_copy_node(self): def test_copy_node(self):
px = [1, 2, 3, 4] px = [1, 2, 3, 4]
py = [1] py = [1]

View File

@@ -115,6 +115,38 @@ class Stree_test(unittest.TestCase):
yp = clf.fit(X, y).predict(X[:num, :]) yp = clf.fit(X, y).predict(X[:num, :])
self.assertListEqual(y[:num].tolist(), yp.tolist()) self.assertListEqual(y[:num].tolist(), yp.tolist())
def test_multiple_predict_proba(self):
expected = {
"liblinear": {
0: [0.02401129943502825, 0.9759887005649718],
17: [0.9282970550576184, 0.07170294494238157],
},
"linear": {
0: [0.029329608938547486, 0.9706703910614525],
17: [0.9298469387755102, 0.07015306122448979],
},
"rbf": {
0: [0.023448275862068966, 0.976551724137931],
17: [0.9458064516129032, 0.05419354838709677],
},
"poly": {
0: [0.01601164483260553, 0.9839883551673945],
17: [0.9089790897908979, 0.0910209102091021],
},
}
indices = [0, 17]
X, y = load_dataset(self._random_state)
for kernel in ["liblinear", "linear", "rbf", "poly"]:
clf = Stree(
kernel=kernel,
multiclass_strategy="ovr" if kernel == "liblinear" else "ovo",
random_state=self._random_state,
)
yp = clf.fit(X, y).predict_proba(X)
for index in indices:
for exp, comp in zip(expected[kernel][index], yp[index]):
self.assertAlmostEqual(exp, comp)
def test_single_vs_multiple_prediction(self): def test_single_vs_multiple_prediction(self):
"""Check if predicting sample by sample gives the same result as """Check if predicting sample by sample gives the same result as
predicting all samples at once predicting all samples at once
@@ -695,7 +727,7 @@ class Stree_test(unittest.TestCase):
) )
expected_tail = ( expected_tail = (
' [shape=box style=filled label="class=1 impurity=0.000 ' ' [shape=box style=filled label="class=1 impurity=0.000 '
'classes=[1] samples=[1]"];\n}\n' 'counts=[0 1 0]"];\n}\n'
) )
self.assertEqual(clf.graph(), expected_head + "}\n") self.assertEqual(clf.graph(), expected_head + "}\n")
clf.fit(X, y) clf.fit(X, y)
@@ -715,7 +747,7 @@ class Stree_test(unittest.TestCase):
) )
expected_tail = ( expected_tail = (
' [shape=box style=filled label="class=1 impurity=0.000 ' ' [shape=box style=filled label="class=1 impurity=0.000 '
'classes=[1] samples=[1]"];\n}\n' 'counts=[0 1 0]"];\n}\n'
) )
self.assertEqual(clf.graph("Sample title"), expected_head + "}\n") self.assertEqual(clf.graph("Sample title"), expected_head + "}\n")
clf.fit(X, y) clf.fit(X, y)