mirror of
https://github.com/Doctorado-ML/Odte.git
synced 2025-07-11 16:22:00 +00:00
Add initialization parameters
Add first notebook
This commit is contained in:
parent
50e25bc372
commit
278a5ff28a
414
notebooks/benchmark.ipynb
Normal file
414
notebooks/benchmark.ipynb
Normal file
@ -0,0 +1,414 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Compare STree with different estimators"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Setup\n",
|
||||||
|
"Uncomment the next cell if STree is not already installed"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"#\n",
|
||||||
|
"# Google Colab setup\n",
|
||||||
|
"#\n",
|
||||||
|
"#!pip install git+https://github.com/doctorado-ml/stree"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import datetime, time\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"from sklearn.model_selection import train_test_split\n",
|
||||||
|
"from sklearn import tree\n",
|
||||||
|
"from sklearn.metrics import classification_report, confusion_matrix, f1_score\n",
|
||||||
|
"from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier, GradientBoostingClassifier\n",
|
||||||
|
"from stree import Stree\n",
|
||||||
|
"from odte import Odte"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import os\n",
|
||||||
|
"if not os.path.isfile('data/creditcard.csv'):\n",
|
||||||
|
" !wget --no-check-certificate --content-disposition http://nube.jccm.es/index.php/s/Zs7SYtZQJ3RQ2H2/download\n",
|
||||||
|
" !tar xzf creditcard.tgz"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Tests"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"2020-06-15 00:46:56\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"print(datetime.date.today(), time.strftime(\"%H:%M:%S\"))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Load dataset and normalize values"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Load Dataset\n",
|
||||||
|
"df = pd.read_csv('data/creditcard.csv')\n",
|
||||||
|
"df.shape\n",
|
||||||
|
"random_state = 2020"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Fraud: 0.173% 492\n",
|
||||||
|
"Valid: 99.827% 284,315\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"print(\"Fraud: {0:.3f}% {1}\".format(df.Class[df.Class == 1].count()*100/df.shape[0], df.Class[df.Class == 1].count()))\n",
|
||||||
|
"print(\"Valid: {0:.3f}% {1:,}\".format(df.Class[df.Class == 0].count()*100/df.shape[0], df.Class[df.Class == 0].count()))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Normalize Amount\n",
|
||||||
|
"from sklearn.preprocessing import RobustScaler\n",
|
||||||
|
"values = RobustScaler().fit_transform(df.Amount.values.reshape(-1, 1))\n",
|
||||||
|
"df['Amount_Scaled'] = values"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"X shape: (284807, 29)\n",
|
||||||
|
"y shape: (284807,)\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# Remove unneeded features\n",
|
||||||
|
"y = df.Class.values\n",
|
||||||
|
"X = df.drop(['Class', 'Time', 'Amount'], axis=1).values\n",
|
||||||
|
"print(f\"X shape: {X.shape}\\ny shape: {y.shape}\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Build the models"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 9,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Divide dataset\n",
|
||||||
|
"train_size = .7\n",
|
||||||
|
"Xtrain, Xtest, ytrain, ytest = train_test_split(X, y, train_size=train_size, shuffle=True, random_state=random_state, stratify=y)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 10,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Linear Tree\n",
|
||||||
|
"linear_tree = tree.DecisionTreeClassifier(random_state=random_state)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 11,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Random Forest\n",
|
||||||
|
"random_forest = RandomForestClassifier(random_state=random_state)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 12,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Stree\n",
|
||||||
|
"stree = Stree(random_state=random_state, C=.01)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 13,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# AdaBoost\n",
|
||||||
|
"adaboost = AdaBoostClassifier(random_state=random_state)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 14,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Gradient Boosting\n",
|
||||||
|
"gradient = GradientBoostingClassifier(random_state=random_state)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 15,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"ename": "NameError",
|
||||||
|
"evalue": "name 'Odte' is not defined",
|
||||||
|
"output_type": "error",
|
||||||
|
"traceback": [
|
||||||
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||||
|
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
||||||
|
"\u001b[0;32m<ipython-input-15-98265fce1448>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0modte\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mOdte\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrandom_state\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrandom_state\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
||||||
|
"\u001b[0;31mNameError\u001b[0m: name 'Odte' is not defined"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"odte = Odte(random_state=random_state)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Do the test"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def try_model(name, model):\n",
|
||||||
|
" print(f\"************************** {name} **********************\")\n",
|
||||||
|
" now = time.time()\n",
|
||||||
|
" model.fit(Xtrain, ytrain)\n",
|
||||||
|
" spent = time.time() - now\n",
|
||||||
|
" print(f\"Train Model {name} took: {spent:.4} seconds\")\n",
|
||||||
|
" predict = model.predict(Xtrain)\n",
|
||||||
|
" predictt = model.predict(Xtest)\n",
|
||||||
|
" print(f\"=========== {name} - Train {Xtrain.shape[0]:,} samples =============\",)\n",
|
||||||
|
" print(classification_report(ytrain, predict, digits=6))\n",
|
||||||
|
" print(f\"=========== {name} - Test {Xtest.shape[0]:,} samples =============\")\n",
|
||||||
|
" print(classification_report(ytest, predictt, digits=6))\n",
|
||||||
|
" print(\"Confusion Matrix in Train\")\n",
|
||||||
|
" print(confusion_matrix(ytrain, predict))\n",
|
||||||
|
" print(\"Confusion Matrix in Test\")\n",
|
||||||
|
" print(confusion_matrix(ytest, predictt))\n",
|
||||||
|
" return f1_score(ytest, predictt), spent"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Train & Test models\n",
|
||||||
|
"models = {\n",
|
||||||
|
" 'Linear Tree':linear_tree, 'Random Forest': random_forest, 'Stree (SVM Tree)': stree, \n",
|
||||||
|
" 'AdaBoost model': adaboost, 'Odte': odte #'Gradient Boost.': gradient\n",
|
||||||
|
"}\n",
|
||||||
|
"\n",
|
||||||
|
"best_f1 = 0\n",
|
||||||
|
"outcomes = []\n",
|
||||||
|
"for name, model in models.items():\n",
|
||||||
|
" f1, time_spent = try_model(name, model)\n",
|
||||||
|
" outcomes.append((name, f1, time_spent))\n",
|
||||||
|
" if f1 > best_f1:\n",
|
||||||
|
" best_model = name\n",
|
||||||
|
" best_time = time_spent\n",
|
||||||
|
" best_f1 = f1"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"print(\"*\"*110)\n",
|
||||||
|
"print(f\"*The best f1 model is {best_model}, with a f1 score: {best_f1:.4} in {best_time:.6} seconds with {train_size:,} samples in train dataset\")\n",
|
||||||
|
"print(\"*\"*110)\n",
|
||||||
|
"for name, f1, time_spent in outcomes:\n",
|
||||||
|
" print(f\"Model: {name}\\t Time: {time_spent:6.2f} seconds\\t f1: {f1:.4}\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "raw",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"**************************************************************************************************************\n",
|
||||||
|
"*The best f1 model is Random Forest, with a f1 score: 0.8815 in 152.54 seconds with 0.7 samples in train dataset\n",
|
||||||
|
"**************************************************************************************************************\n",
|
||||||
|
"Model: Linear Tree\t Time: 13.52 seconds\t f1: 0.7645\n",
|
||||||
|
"Model: Random Forest\t Time: 152.54 seconds\t f1: 0.8815\n",
|
||||||
|
"Model: Stree (SVM Tree)\t Time: 32.55 seconds\t f1: 0.8603\n",
|
||||||
|
"Model: AdaBoost model\t Time: 47.34 seconds\t f1: 0.7509\n",
|
||||||
|
"Model: Gradient Boost.\t Time: 244.12 seconds\t f1: 0.5259"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"```\n",
|
||||||
|
"******************************************************************************************************************\n",
|
||||||
|
"*The best f1 model is Random Forest, with a f1 score: 0.8815 in 218.966 seconds with 0.7 samples in train dataset\n",
|
||||||
|
"******************************************************************************************************************\n",
|
||||||
|
"Model: Linear Tree Time: 23.05 seconds\t f1: 0.7645\n",
|
||||||
|
"Model: Random Forest\t Time: 218.97 seconds\t f1: 0.8815\n",
|
||||||
|
"Model: Stree (SVM Tree)\t Time: 49.45 seconds\t f1: 0.8467\n",
|
||||||
|
"Model: AdaBoost model\t Time: 73.83 seconds\t f1: 0.7509\n",
|
||||||
|
"Model: Gradient Boost.\t Time: 388.69 seconds\t f1: 0.5259\n",
|
||||||
|
"Model: Neural Network\t Time: 25.47 seconds\t f1: 0.8328\n",
|
||||||
|
"```"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"hide_input": false,
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.7.6"
|
||||||
|
},
|
||||||
|
"toc": {
|
||||||
|
"base_numbering": 1,
|
||||||
|
"nav_menu": {},
|
||||||
|
"number_sections": true,
|
||||||
|
"sideBar": true,
|
||||||
|
"skip_h1_title": false,
|
||||||
|
"title_cell": "Table of Contents",
|
||||||
|
"title_sidebar": "Contents",
|
||||||
|
"toc_cell": false,
|
||||||
|
"toc_position": {},
|
||||||
|
"toc_section_display": true,
|
||||||
|
"toc_window_display": false
|
||||||
|
},
|
||||||
|
"varInspector": {
|
||||||
|
"cols": {
|
||||||
|
"lenName": 16,
|
||||||
|
"lenType": 16,
|
||||||
|
"lenVar": 40
|
||||||
|
},
|
||||||
|
"kernels_config": {
|
||||||
|
"python": {
|
||||||
|
"delete_cmd_postfix": "",
|
||||||
|
"delete_cmd_prefix": "del ",
|
||||||
|
"library": "var_list.py",
|
||||||
|
"varRefreshCmd": "print(var_dic_list())"
|
||||||
|
},
|
||||||
|
"r": {
|
||||||
|
"delete_cmd_postfix": ") ",
|
||||||
|
"delete_cmd_prefix": "rm(",
|
||||||
|
"library": "var_list.r",
|
||||||
|
"varRefreshCmd": "cat(var_dic_list()) "
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"position": {
|
||||||
|
"height": "392px",
|
||||||
|
"left": "1518px",
|
||||||
|
"right": "20px",
|
||||||
|
"top": "40px",
|
||||||
|
"width": "392px"
|
||||||
|
},
|
||||||
|
"types_to_exclude": [
|
||||||
|
"module",
|
||||||
|
"function",
|
||||||
|
"builtin_function_or_method",
|
||||||
|
"instance",
|
||||||
|
"_Feature"
|
||||||
|
],
|
||||||
|
"window_display": true
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 4
|
||||||
|
}
|
@ -34,12 +34,14 @@ class Odte(BaseEstimator, ClassifierMixin):
|
|||||||
min_samples_split: int = 0,
|
min_samples_split: int = 0,
|
||||||
bootstrap: bool = True,
|
bootstrap: bool = True,
|
||||||
split_criteria: str = "min_distance",
|
split_criteria: str = "min_distance",
|
||||||
|
criterion: str = "gini",
|
||||||
tol: float = 1e-4,
|
tol: float = 1e-4,
|
||||||
gamma="scale",
|
gamma="scale",
|
||||||
degree: int = 3,
|
degree: int = 3,
|
||||||
kernel: str = "linear",
|
kernel: str = "linear",
|
||||||
max_features="auto",
|
max_features="auto",
|
||||||
max_samples=None,
|
max_samples=None,
|
||||||
|
splitter: str = "random",
|
||||||
):
|
):
|
||||||
self.n_estimators = n_estimators
|
self.n_estimators = n_estimators
|
||||||
self.bootstrap = bootstrap
|
self.bootstrap = bootstrap
|
||||||
@ -52,11 +54,14 @@ class Odte(BaseEstimator, ClassifierMixin):
|
|||||||
min_samples_split=min_samples_split,
|
min_samples_split=min_samples_split,
|
||||||
max_depth=max_depth,
|
max_depth=max_depth,
|
||||||
split_criteria=split_criteria,
|
split_criteria=split_criteria,
|
||||||
|
criterion=criterion,
|
||||||
kernel=kernel,
|
kernel=kernel,
|
||||||
max_iter=max_iter,
|
max_iter=max_iter,
|
||||||
tol=tol,
|
tol=tol,
|
||||||
degree=degree,
|
degree=degree,
|
||||||
gamma=gamma,
|
gamma=gamma,
|
||||||
|
splitter=splitter,
|
||||||
|
max_features=max_features,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _initialize_random(self) -> np.random.mtrand.RandomState:
|
def _initialize_random(self) -> np.random.mtrand.RandomState:
|
||||||
|
@ -14,7 +14,7 @@ class Odte_test(unittest.TestCase):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def test_max_samples_bogus(self):
|
def test_max_samples_bogus(self):
|
||||||
values = [0, 3000, 1.1, 0.0, "hi!"]
|
values = [0, 3000, 1.1, 0.0, "duck"]
|
||||||
for max_samples in values:
|
for max_samples in values:
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
tclf = Odte(max_samples=max_samples)
|
tclf = Odte(max_samples=max_samples)
|
||||||
@ -75,7 +75,10 @@ class Odte_test(unittest.TestCase):
|
|||||||
X, y = load_dataset(self._random_state)
|
X, y = load_dataset(self._random_state)
|
||||||
expected = y
|
expected = y
|
||||||
tclf = Odte(
|
tclf = Odte(
|
||||||
random_state=self._random_state, n_estimators=10, kernel="linear"
|
random_state=self._random_state,
|
||||||
|
max_features=None,
|
||||||
|
kernel="linear",
|
||||||
|
max_samples=0.1,
|
||||||
)
|
)
|
||||||
computed = tclf.fit(X, y).predict(X)
|
computed = tclf.fit(X, y).predict(X)
|
||||||
self.assertListEqual(expected[:27].tolist(), computed[:27].tolist())
|
self.assertListEqual(expected[:27].tolist(), computed[:27].tolist())
|
||||||
@ -83,6 +86,23 @@ class Odte_test(unittest.TestCase):
|
|||||||
def test_score(self):
|
def test_score(self):
|
||||||
X, y = load_dataset(self._random_state)
|
X, y = load_dataset(self._random_state)
|
||||||
expected = 0.9526666666666667
|
expected = 0.9526666666666667
|
||||||
tclf = Odte(random_state=self._random_state, n_estimators=10)
|
tclf = Odte(
|
||||||
|
random_state=self._random_state, max_features=None, n_estimators=10
|
||||||
|
)
|
||||||
|
computed = tclf.fit(X, y).score(X, y)
|
||||||
|
self.assertAlmostEqual(expected, computed)
|
||||||
|
|
||||||
|
def test_score_splitter_max_features(self):
|
||||||
|
X, y = load_dataset(self._random_state, n_features=12, n_samples=150)
|
||||||
|
results = [1.0, 0.94, 0.9933333333333333, 0.9933333333333333]
|
||||||
|
for max_features in ["auto", None]:
|
||||||
|
for splitter in ["best", "random"]:
|
||||||
|
tclf = Odte(
|
||||||
|
random_state=self._random_state,
|
||||||
|
splitter=splitter,
|
||||||
|
max_features=max_features,
|
||||||
|
n_estimators=10,
|
||||||
|
)
|
||||||
|
expected = results.pop(0)
|
||||||
computed = tclf.fit(X, y).score(X, y)
|
computed = tclf.fit(X, y).score(X, y)
|
||||||
self.assertAlmostEqual(expected, computed)
|
self.assertAlmostEqual(expected, computed)
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
from sklearn.datasets import make_classification
|
from sklearn.datasets import make_classification
|
||||||
|
|
||||||
|
|
||||||
def load_dataset(random_state=0, n_classes=2):
|
def load_dataset(random_state=0, n_classes=2, n_features=3, n_samples=1500):
|
||||||
X, y = make_classification(
|
X, y = make_classification(
|
||||||
n_samples=1500,
|
n_samples=n_samples,
|
||||||
n_features=3,
|
n_features=n_features,
|
||||||
n_informative=3,
|
n_informative=3,
|
||||||
n_redundant=0,
|
n_redundant=0,
|
||||||
n_repeated=0,
|
n_repeated=0,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user