mirror of
https://github.com/Doctorado-ML/STree.git
synced 2025-08-15 15:36:00 +00:00
#3 First try, change LinearSVC to SVC
make a builder start changing tests
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
|
||||
# Stree
|
||||
|
||||
Oblique Tree classifier based on SVM nodes. The nodes are built and splitted with sklearn LinearSVC models.Stree is a sklearn estimator and can be integrated in pipelines, grid searches, etc.
|
||||
Oblique Tree classifier based on SVM nodes. The nodes are built and split with sklearn SVC models. Stree is a sklearn estimator and can be integrated in pipelines, grid searches, etc.
|
||||
|
||||

|
||||
|
||||
|
@@ -9,7 +9,7 @@
|
||||
"import time\n",
|
||||
"from sklearn.ensemble import AdaBoostClassifier\n",
|
||||
"from sklearn.tree import DecisionTreeClassifier\n",
|
||||
"from sklearn.svm import LinearSVC\n",
|
||||
"from sklearn.svm import SVC\n",
|
||||
"from sklearn.model_selection import GridSearchCV, train_test_split\n",
|
||||
"from sklearn.datasets import load_iris\n",
|
||||
"from stree import Stree"
|
||||
@@ -131,7 +131,7 @@
|
||||
],
|
||||
"source": [
|
||||
"now = time.time()\n",
|
||||
"clf3 = AdaBoostClassifier(LinearSVC(random_state=random_state), n_estimators=100, random_state=random_state, algorithm='SAMME')\n",
|
||||
"clf3 = AdaBoostClassifier(SVC(kernel=\"linear\",random_state=random_state), n_estimators=100, random_state=random_state, algorithm='SAMME')\n",
|
||||
"clf3.fit(Xtrain, ytrain)\n",
|
||||
"print(\"Score Train: \", clf3.score(Xtrain, ytrain))\n",
|
||||
"print(\"Score Test: \", clf3.score(Xtest, ytest))\n",
|
||||
|
@@ -20,7 +20,7 @@
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"from sklearn.svm import LinearSVC\n",
|
||||
"from sklearn.svm import SVC\n",
|
||||
"from sklearn.tree import DecisionTreeClassifier\n",
|
||||
"from sklearn.model_selection import train_test_split\n",
|
||||
"from sklearn.datasets import make_classification, load_iris, load_wine\n",
|
||||
|
@@ -8,7 +8,7 @@
|
||||
"source": [
|
||||
"from sklearn.ensemble import AdaBoostClassifier\n",
|
||||
"from sklearn.tree import DecisionTreeClassifier\n",
|
||||
"from sklearn.svm import LinearSVC\n",
|
||||
"from sklearn.svm import SVC\n",
|
||||
"from sklearn.model_selection import GridSearchCV, train_test_split\n",
|
||||
"from sklearn.datasets import load_iris\n",
|
||||
"from stree import Stree"
|
||||
@@ -109,7 +109,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"LinearSVC().get_params(), DecisionTreeClassifier().get_params()"
|
||||
"SVC(kernel=\"linear\",).get_params(), DecisionTreeClassifier().get_params()"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@@ -14,13 +14,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"from sklearn.svm import LinearSVC\n",
|
||||
"from sklearn.svm import SVC\n",
|
||||
"from sklearn.tree import DecisionTreeClassifier\n",
|
||||
"from sklearn.datasets import make_classification, load_iris, load_wine\n",
|
||||
"from sklearn.model_selection import train_test_split\n",
|
||||
@@ -42,7 +42,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@@ -96,13 +96,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": "Accuracy of Train without weights 0.996415770609319\nAccuracy of Train with weights 0.994026284348865\nAccuracy of Tests without weights 0.9665738161559888\nAccuracy of Tests with weights 0.9721448467966574\n"
|
||||
"text": "Accuracy of Train without weights 1.0\nAccuracy of Train with weights 1.0\nAccuracy of Tests without weights 0.9554317548746518\nAccuracy of Tests with weights 0.9777158774373259\n"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
@@ -115,13 +115,19 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"tags": [
|
||||
"outputPrepend"
|
||||
]
|
||||
},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": "************** C=0.001 ****************************\nClassifier's accuracy (train): 0.9749\nClassifier's accuracy (test) : 0.9749\nroot\nroot - Down, <pure> - Leaf class=1 belief= 1.000000 counts=(array([1]), array([117]))\nroot - Up, <cgaf> - Leaf class=0 belief= 0.970833 counts=(array([0, 1]), array([699, 21]))\n\n**************************************************\n************** C=0.01 ****************************\nClassifier's accuracy (train): 0.9797\nClassifier's accuracy (test) : 0.9777\nroot\nroot - Down, <pure> - Leaf class=1 belief= 1.000000 counts=(array([1]), array([121]))\nroot - Up, <cgaf> - Leaf class=0 belief= 0.976257 counts=(array([0, 1]), array([699, 17]))\n\n**************************************************\n************** C=1 ****************************\nClassifier's accuracy (train): 0.9869\nClassifier's accuracy (test) : 0.9805\nroot\nroot - Down, <pure> - Leaf class=1 belief= 1.000000 counts=(array([1]), array([127]))\nroot - Up, <cgaf> - Leaf class=0 belief= 0.984507 counts=(array([0, 1]), array([699, 11]))\n\n**************************************************\n************** C=5 ****************************\nClassifier's accuracy (train): 0.9892\nClassifier's accuracy (test) : 0.9721\nroot\nroot - Down\nroot - Down - Down, <pure> - Leaf class=1 belief= 1.000000 counts=(array([1]), array([128]))\nroot - Down - Up, <pure> - Leaf class=0 belief= 1.000000 counts=(array([0]), array([3]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, <pure> - Leaf class=1 belief= 1.000000 counts=(array([1]), array([1]))\nroot - Up - Down - Up, <pure> - Leaf class=0 belief= 1.000000 counts=(array([0]), array([1]))\nroot - Up - Up, <cgaf> - Leaf class=0 belief= 0.987216 counts=(array([0, 1]), array([695, 9]))\n\n**************************************************\n************** C=17 ****************************\nClassifier's accuracy (train): 0.9952\nClassifier's accuracy (test) : 0.9749\nroot\nroot - Down\nroot - Down - Down, <pure> - Leaf 
class=1 belief= 1.000000 counts=(array([1]), array([120]))\nroot - Down - Up, <pure> - Leaf class=0 belief= 1.000000 counts=(array([0]), array([57]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, <pure> - Leaf class=1 belief= 1.000000 counts=(array([1]), array([14]))\nroot - Up - Down - Up, <pure> - Leaf class=0 belief= 1.000000 counts=(array([0]), array([12]))\nroot - Up - Up, <cgaf> - Leaf class=0 belief= 0.993691 counts=(array([0, 1]), array([630, 4]))\n\n**************************************************\n0.1084 secs\n"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"t = time.time()\n",
|
||||
"for C in (.001, .01, 1, 5, 17):\n",
|
||||
@@ -137,9 +143,15 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": "[[0.95399748 0.04600252]\n [0.92625258 0.07374742]\n [0.97804877 0.02195123]\n [0.94803313 0.05196687]]\n"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"from sklearn.preprocessing import StandardScaler\n",
|
||||
@@ -154,9 +166,15 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": "root\nroot - Down\nroot - Down - Down, <pure> - Leaf class=1 belief= 1.000000 counts=(array([1]), array([120]))\nroot - Down - Up, <pure> - Leaf class=0 belief= 1.000000 counts=(array([0]), array([57]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, <pure> - Leaf class=1 belief= 1.000000 counts=(array([1]), array([14]))\nroot - Up - Down - Up, <pure> - Leaf class=0 belief= 1.000000 counts=(array([0]), array([12]))\nroot - Up - Up, <cgaf> - Leaf class=0 belief= 0.993691 counts=(array([0, 1]), array([630, 4]))\n"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"#check iterator\n",
|
||||
"for i in list(clf):\n",
|
||||
@@ -165,9 +183,15 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": "root\nroot - Down\nroot - Down - Down, <pure> - Leaf class=1 belief= 1.000000 counts=(array([1]), array([120]))\nroot - Down - Up, <pure> - Leaf class=0 belief= 1.000000 counts=(array([0]), array([57]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, <pure> - Leaf class=1 belief= 1.000000 counts=(array([1]), array([14]))\nroot - Up - Down - Up, <pure> - Leaf class=0 belief= 1.000000 counts=(array([0]), array([12]))\nroot - Up - Up, <cgaf> - Leaf class=0 belief= 0.993691 counts=(array([0, 1]), array([630, 4]))\n"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"#check iterator again\n",
|
||||
"for i in clf:\n",
|
||||
@@ -176,7 +200,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@@ -187,9 +211,15 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "stream",
|
||||
"name": "stdout",
|
||||
"text": "1 functools.partial(<function check_no_attributes_set_in_init at 0x128922b90>, 'Stree')\n2 functools.partial(<function check_estimators_dtypes at 0x128918cb0>, 'Stree')\n3 functools.partial(<function check_fit_score_takes_y at 0x128918b90>, 'Stree')\n4 functools.partial(<function check_sample_weights_pandas_series at 0x1289144d0>, 'Stree')\n5 functools.partial(<function check_sample_weights_not_an_array at 0x1289145f0>, 'Stree')\n6 functools.partial(<function check_sample_weights_list at 0x128914710>, 'Stree')\n7 functools.partial(<function check_sample_weights_shape at 0x128914830>, 'Stree')\n8 functools.partial(<function check_sample_weights_invariance at 0x128914950>, 'Stree')\n9 functools.partial(<function check_estimators_fit_returns_self at 0x12891ecb0>, 'Stree')\n10 functools.partial(<function check_estimators_fit_returns_self at 0x12891ecb0>, 'Stree', readonly_memmap=True)\n11 functools.partial(<function check_complex_data at 0x128914b00>, 'Stree')\n12 functools.partial(<function check_dtype_object at 0x128914a70>, 'Stree')\n13 functools.partial(<function check_estimators_empty_data_messages at 0x128918dd0>, 'Stree')\n14 functools.partial(<function check_pipeline_consistency at 0x128918a70>, 'Stree')\n15 functools.partial(<function check_estimators_nan_inf at 0x128918ef0>, 'Stree')\n16 functools.partial(<function check_estimators_overwrite_params at 0x128922a70>, 'Stree')\n17 functools.partial(<function check_estimator_sparse_data at 0x1289143b0>, 'Stree')\n18 functools.partial(<function check_estimators_pickle at 0x12891e170>, 'Stree')\n19 functools.partial(<function check_classifier_data_not_an_array at 0x128922dd0>, 'Stree')\n20 functools.partial(<function check_classifiers_one_label at 0x12891e830>, 'Stree')\n21 functools.partial(<function check_classifiers_classes at 0x128922290>, 'Stree')\n22 functools.partial(<function check_estimators_partial_fit_n_features at 0x12891e290>, 'Stree')\n23 functools.partial(<function check_classifiers_train at 
0x12891e950>, 'Stree')\n24 functools.partial(<function check_classifiers_train at 0x12891e950>, 'Stree', readonly_memmap=True)\n25 functools.partial(<function check_classifiers_train at 0x12891e950>, 'Stree', readonly_memmap=True, X_dtype='float32')\n26 functools.partial(<function check_classifiers_regression_target at 0x1289278c0>, 'Stree')\n27 functools.partial(<function check_supervised_y_no_nan at 0x12890c4d0>, 'Stree')\n28 functools.partial(<function check_supervised_y_2d at 0x12891eef0>, 'Stree')\n29 functools.partial(<function check_estimators_unfitted at 0x12891edd0>, 'Stree')\n30 functools.partial(<function check_non_transformer_estimators_n_iter at 0x128927440>, 'Stree')\n31 functools.partial(<function check_decision_proba_consistency at 0x1289279e0>, 'Stree')\n32 functools.partial(<function check_fit2d_predict1d at 0x128918050>, 'Stree')\n33 functools.partial(<function check_methods_subset_invariance at 0x128918200>, 'Stree')\n34 functools.partial(<function check_fit2d_1sample at 0x128918320>, 'Stree')\n35 functools.partial(<function check_fit2d_1feature at 0x128918440>, 'Stree')\n36 functools.partial(<function check_fit1d at 0x128918560>, 'Stree')\n37 functools.partial(<function check_get_params_invariance at 0x128927680>, 'Stree')\n38 functools.partial(<function check_set_params at 0x1289277a0>, 'Stree')\n39 functools.partial(<function check_dict_unchanged at 0x128914c20>, 'Stree')\n40 functools.partial(<function check_dont_overwrite_parameters at 0x128914ef0>, 'Stree')\n41 functools.partial(<function check_fit_idempotent at 0x128927b90>, 'Stree')\n42 functools.partial(<function check_n_features_in at 0x128927c20>, 'Stree')\n43 functools.partial(<function check_requires_y_none at 0x128927cb0>, 'Stree')\n"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Make checks one by one\n",
|
||||
"c = 0\n",
|
||||
@@ -199,6 +229,13 @@
|
||||
" print(c, check[1])\n",
|
||||
" check[1](check[0])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
@@ -34,7 +34,7 @@
|
||||
"import random\n",
|
||||
"import numpy as np\n",
|
||||
"from sklearn.datasets import make_blobs\n",
|
||||
"from sklearn.svm import LinearSVC\n",
|
||||
"from sklearn.svm import SVC\n",
|
||||
"from stree import Stree, Stree_grapher"
|
||||
]
|
||||
},
|
||||
|
@@ -4,14 +4,13 @@ __copyright__ = "Copyright 2020, Ricardo Montañana Gómez"
|
||||
__license__ = "MIT"
|
||||
__version__ = "0.9"
|
||||
Build an oblique tree classifier based on SVM Trees
|
||||
Uses LinearSVC
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from sklearn.base import BaseEstimator, ClassifierMixin
|
||||
from sklearn.svm import LinearSVC
|
||||
from sklearn.svm import SVC, LinearSVC
|
||||
from sklearn.utils.multiclass import check_classification_targets
|
||||
from sklearn.utils.validation import (
|
||||
check_X_y,
|
||||
@@ -26,12 +25,8 @@ class Snode:
|
||||
dataset assigned to it
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, clf: LinearSVC, X: np.ndarray, y: np.ndarray, title: str
|
||||
):
|
||||
def __init__(self, clf: SVC, X: np.ndarray, y: np.ndarray, title: str):
|
||||
self._clf = clf
|
||||
self._vector = None if clf is None else clf.coef_
|
||||
self._interceptor = 0.0 if clf is None else clf.intercept_
|
||||
self._title = title
|
||||
self._belief = 0.0
|
||||
# Only store dataset in Testing
|
||||
@@ -126,6 +121,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
def __init__(
|
||||
self,
|
||||
C: float = 1.0,
|
||||
kernel: str = "linear",
|
||||
max_iter: int = 1000,
|
||||
random_state: int = None,
|
||||
max_depth: int = None,
|
||||
@@ -135,6 +131,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
):
|
||||
self.max_iter = max_iter
|
||||
self.C = C
|
||||
self.kernel = kernel
|
||||
self.random_state = random_state
|
||||
self.use_predictions = use_predictions
|
||||
self.max_depth = max_depth
|
||||
@@ -161,8 +158,8 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
:return: array of distances of each sample to the hyperplane
|
||||
:rtype: np.array
|
||||
"""
|
||||
coef = node._vector[0, :].reshape(-1, data.shape[1])
|
||||
return data.dot(coef.T) + node._interceptor[0]
|
||||
coef = node._clf.coef_[0, :].reshape(-1, data.shape[1])
|
||||
return data.dot(coef.T) + node._clf.intercept_[0]
|
||||
|
||||
def _split_array(self, origin: np.array, down: np.array) -> list:
|
||||
"""Split an array in two based on indices passed as down and its complement
|
||||
@@ -266,6 +263,26 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
|
||||
run_tree(self.tree_)
|
||||
|
||||
def _build_clf(self):
|
||||
""" Select the correct classifier for the node
|
||||
"""
|
||||
|
||||
return (
|
||||
LinearSVC(
|
||||
max_iter=self.max_iter,
|
||||
random_state=self.random_state,
|
||||
C=self.C,
|
||||
tol=self.tol,
|
||||
)
|
||||
if self.kernel == "linear"
|
||||
else SVC(
|
||||
kernel=self.kernel,
|
||||
max_iter=self.max_iter,
|
||||
tol=self.tol,
|
||||
C=self.C,
|
||||
)
|
||||
)
|
||||
|
||||
def train(
|
||||
self,
|
||||
X: np.ndarray,
|
||||
@@ -296,9 +313,7 @@ class Stree(BaseEstimator, ClassifierMixin):
|
||||
# only 1 class => pure dataset
|
||||
return Snode(None, X, y, title + ", <pure>")
|
||||
# Train the model
|
||||
clf = LinearSVC(
|
||||
max_iter=self.max_iter, random_state=self.random_state, C=self.C
|
||||
) # , sample_weight=sample_weight)
|
||||
clf = self._build_clf()
|
||||
clf.fit(X, y, sample_weight=sample_weight)
|
||||
tree = Snode(clf, X, y, title)
|
||||
self.depth_ = max(depth, self.depth_)
|
||||
|
@@ -73,10 +73,10 @@ class Snode_graph(Snode):
|
||||
# get the splitting hyperplane
|
||||
def hyperplane(x, y):
|
||||
return (
|
||||
-self._interceptor
|
||||
- self._vector[0][0] * x
|
||||
- self._vector[0][1] * y
|
||||
) / self._vector[0][2]
|
||||
-self._clf.intercept_
|
||||
- self._clf.coef_[0][0] * x
|
||||
- self._clf.coef_[0][1] * y
|
||||
) / self._clf.coef_[0][2]
|
||||
|
||||
tmpx = np.linspace(self._X[:, 0].min(), self._X[:, 0].max())
|
||||
tmpy = np.linspace(self._X[:, 1].min(), self._X[:, 1].max())
|
||||
|
@@ -76,7 +76,9 @@ class Stree_grapher_test(unittest.TestCase):
|
||||
|
||||
def test_save_all(self):
|
||||
folder_name = "/tmp/"
|
||||
file_names = [f"{folder_name}STnode{i}.png" for i in range(1, 8)]
|
||||
file_names = [
|
||||
os.path.join(folder_name, f"STnode{i}.png") for i in range(1, 8)
|
||||
]
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
matplotlib.use("Agg")
|
||||
@@ -160,8 +162,6 @@ class Snode_graph_test(unittest.TestCase):
|
||||
# only exclude pure leaves
|
||||
self.assertIsNotNone(node._clf)
|
||||
self.assertIsNotNone(node._clf.coef_)
|
||||
self.assertIsNotNone(node._vector)
|
||||
self.assertIsNotNone(node._interceptor)
|
||||
if node.is_leaf():
|
||||
return
|
||||
run_tree(node.get_down())
|
||||
@@ -171,7 +171,7 @@ class Snode_graph_test(unittest.TestCase):
|
||||
|
||||
def test_save_hyperplane(self):
|
||||
folder_name = "/tmp/"
|
||||
file_name = f"{folder_name}STnode1.png"
|
||||
file_name = os.path.join(folder_name, "STnode1.png")
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
matplotlib.use("Agg")
|
||||
|
@@ -84,22 +84,6 @@ class Stree_test(unittest.TestCase):
|
||||
"""
|
||||
self._check_tree(self._clf.tree_)
|
||||
|
||||
def _get_file_data(self, file_name: str) -> tuple:
|
||||
"""Return X, y from data, y is the last column in array
|
||||
|
||||
Arguments:
|
||||
file_name {str} -- the file name
|
||||
|
||||
Returns:
|
||||
tuple -- tuple with samples, categories
|
||||
"""
|
||||
data = np.genfromtxt(file_name, delimiter=",")
|
||||
data = np.array(data)
|
||||
column_y = data.shape[1] - 1
|
||||
fy = data[:, column_y]
|
||||
fx = np.delete(data, column_y, axis=1)
|
||||
return fx, fy
|
||||
|
||||
def _find_out(
|
||||
self, px: np.array, x_original: np.array, y_original
|
||||
) -> list:
|
||||
@@ -134,11 +118,18 @@ class Stree_test(unittest.TestCase):
|
||||
|
||||
def test_score(self):
|
||||
X, y = get_dataset(self._random_state)
|
||||
accuracy_score = self._clf.score(X, y)
|
||||
yp = self._clf.predict(X)
|
||||
accuracy_computed = np.mean(yp == y)
|
||||
self.assertEqual(accuracy_score, accuracy_computed)
|
||||
self.assertGreater(accuracy_score, 0.9)
|
||||
for kernel in ["linear"]:
|
||||
clf = Stree(
|
||||
random_state=self._random_state,
|
||||
kernel=kernel,
|
||||
use_predictions=True,
|
||||
)
|
||||
clf.fit(X, y)
|
||||
accuracy_score = clf.score(X, y)
|
||||
yp = clf.predict(X)
|
||||
accuracy_computed = np.mean(yp == y)
|
||||
self.assertEqual(accuracy_score, accuracy_computed)
|
||||
self.assertGreater(accuracy_score, 0.9)
|
||||
|
||||
def test_single_predict_proba(self):
|
||||
"""Check that element 28 has a prediction different than the current
|
||||
@@ -306,10 +297,11 @@ class Stree_test(unittest.TestCase):
|
||||
tcl.fit(*get_dataset(self._random_state))
|
||||
|
||||
def test_check_max_depth(self):
|
||||
depth = 3
|
||||
tcl = Stree(random_state=self._random_state, max_depth=depth)
|
||||
tcl.fit(*get_dataset(self._random_state))
|
||||
self.assertEqual(depth, tcl.depth_)
|
||||
depths = (3, 4)
|
||||
for depth in depths:
|
||||
tcl = Stree(random_state=self._random_state, max_depth=depth)
|
||||
tcl.fit(*get_dataset(self._random_state))
|
||||
self.assertEqual(depth, tcl.depth_)
|
||||
|
||||
def test_unfitted_tree_is_iterable(self):
|
||||
tcl = Stree()
|
||||
@@ -383,8 +375,6 @@ class Snode_test(unittest.TestCase):
|
||||
# only exclude pure leaves
|
||||
self.assertIsNotNone(node._clf)
|
||||
self.assertIsNotNone(node._clf.coef_)
|
||||
self.assertIsNotNone(node._vector)
|
||||
self.assertIsNotNone(node._interceptor)
|
||||
if node.is_leaf():
|
||||
return
|
||||
run_tree(node.get_down())
|
||||
|
Reference in New Issue
Block a user