Compare commits

...

8 Commits

Author SHA1 Message Date
9cb69ebc75 Implement hyperparam. context based normalization 2021-04-15 02:13:30 +02:00
b55f59a3ec Fix compute number of nodes 2021-04-13 22:31:05 +02:00
783d105099 Add another nodes, leaves test 2021-04-09 10:56:54 +02:00
c36f685263 Fix unintended nested if in partition 2021-04-08 08:27:31 +02:00
0f89b044f1 Refactor train method 2021-04-07 01:02:30 +02:00
Ricardo Montañana Gómez
6ba973dfe1 Add a method that return nodes and leaves (#27) (#30)
Add a test
Fix #27
2021-03-23 14:30:32 +01:00
Ricardo Montañana Gómez
460c63a6d0 Fix depth sometimes is wrong (#26) (#29)
Add a test to the tests set
Add depth to node description
Fix iterator and str test due to this addon
2021-03-23 14:08:53 +01:00
Ricardo Montañana Gómez
f438124057 Fix mistakes (#24) (#28)
Put pandas requirements in notebooks
clean requirements.txt
2021-03-23 13:27:32 +01:00
9 changed files with 297 additions and 749 deletions

View File

@@ -20,15 +20,13 @@ pip install git+https://github.com/doctorado-ml/stree
- [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/Doctorado-ML/STree/master?urlpath=lab/tree/notebooks/benchmark.ipynb) Benchmark - [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/Doctorado-ML/STree/master?urlpath=lab/tree/notebooks/benchmark.ipynb) Benchmark
- [![Test](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Doctorado-ML/STree/blob/master/notebooks/benchmark.ipynb) Benchmark - [![benchmark](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Doctorado-ML/STree/blob/master/notebooks/benchmark.ipynb) Benchmark
- [![Test2](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Doctorado-ML/STree/blob/master/notebooks/features.ipynb) Test features - [![features](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Doctorado-ML/STree/blob/master/notebooks/features.ipynb) Some features
- [![Adaboost](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Doctorado-ML/STree/blob/master/notebooks/adaboost.ipynb) Adaboost
- [![Gridsearch](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Doctorado-ML/STree/blob/master/notebooks/gridsearch.ipynb) Gridsearch - [![Gridsearch](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Doctorado-ML/STree/blob/master/notebooks/gridsearch.ipynb) Gridsearch
- [![Test Graphics](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Doctorado-ML/STree/blob/master/notebooks/test_graphs.ipynb) Test Graphics - [![Ensemble](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Doctorado-ML/STree/blob/master/notebooks/ensemble.ipynb) Ensembles
## Hyperparameters ## Hyperparameters

View File

@@ -17,23 +17,25 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"#\n", "#\n",
"# Google Colab setup\n", "# Google Colab setup\n",
"#\n", "#\n",
"#!pip install git+https://github.com/doctorado-ml/stree" "#!pip install git+https://github.com/doctorado-ml/stree\n",
"!pip install pandas"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import datetime, time\n", "import datetime, time\n",
"import os\n",
"import numpy as np\n", "import numpy as np\n",
"import pandas as pd\n", "import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n", "from sklearn.model_selection import train_test_split\n",
@@ -47,11 +49,10 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import os\n",
"if not os.path.isfile('data/creditcard.csv'):\n", "if not os.path.isfile('data/creditcard.csv'):\n",
" !wget --no-check-certificate --content-disposition http://nube.jccm.es/index.php/s/Zs7SYtZQJ3RQ2H2/download\n", " !wget --no-check-certificate --content-disposition http://nube.jccm.es/index.php/s/Zs7SYtZQJ3RQ2H2/download\n",
" !tar xzf creditcard.tgz" " !tar xzf creditcard.tgz"
@@ -66,19 +67,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": null,
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"2021-01-14 11:30:51\n"
]
}
],
"source": [ "source": [
"print(datetime.date.today(), time.strftime(\"%H:%M:%S\"))" "print(datetime.date.today(), time.strftime(\"%H:%M:%S\"))"
] ]
@@ -92,7 +85,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -104,20 +97,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": null,
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fraud: 0.173% 492\n",
"Valid: 99.827% 284,315\n"
]
}
],
"source": [ "source": [
"print(\"Fraud: {0:.3f}% {1}\".format(df.Class[df.Class == 1].count()*100/df.shape[0], df.Class[df.Class == 1].count()))\n", "print(\"Fraud: {0:.3f}% {1}\".format(df.Class[df.Class == 1].count()*100/df.shape[0], df.Class[df.Class == 1].count()))\n",
"print(\"Valid: {0:.3f}% {1:,}\".format(df.Class[df.Class == 0].count()*100/df.shape[0], df.Class[df.Class == 0].count()))" "print(\"Valid: {0:.3f}% {1:,}\".format(df.Class[df.Class == 0].count()*100/df.shape[0], df.Class[df.Class == 0].count()))"
@@ -125,7 +109,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -137,20 +121,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": null,
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"X shape: (284807, 29)\n",
"y shape: (284807,)\n"
]
}
],
"source": [ "source": [
"# Remove unneeded features\n", "# Remove unneeded features\n",
"y = df.Class.values\n", "y = df.Class.values\n",
@@ -167,7 +142,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -178,7 +153,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -188,7 +163,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 11, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -198,7 +173,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -208,7 +183,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 13, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -218,7 +193,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 14, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -235,7 +210,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 15, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -260,194 +235,15 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 16, "execution_count": null,
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"************************** Linear Tree **********************\n",
"Train Model Linear Tree took: 10.25 seconds\n",
"=========== Linear Tree - Train 199,364 samples =============\n",
" precision recall f1-score support\n",
"\n",
" 0 1.000000 1.000000 1.000000 199020\n",
" 1 1.000000 1.000000 1.000000 344\n",
"\n",
" accuracy 1.000000 199364\n",
" macro avg 1.000000 1.000000 1.000000 199364\n",
"weighted avg 1.000000 1.000000 1.000000 199364\n",
"\n",
"=========== Linear Tree - Test 85,443 samples =============\n",
" precision recall f1-score support\n",
"\n",
" 0 0.999578 0.999613 0.999596 85295\n",
" 1 0.772414 0.756757 0.764505 148\n",
"\n",
" accuracy 0.999192 85443\n",
" macro avg 0.885996 0.878185 0.882050 85443\n",
"weighted avg 0.999184 0.999192 0.999188 85443\n",
"\n",
"Confusion Matrix in Train\n",
"[[199020 0]\n",
" [ 0 344]]\n",
"Confusion Matrix in Test\n",
"[[85262 33]\n",
" [ 36 112]]\n",
"************************** Naive Bayes **********************\n",
"Train Model Naive Bayes took: 0.09943 seconds\n",
"=========== Naive Bayes - Train 199,364 samples =============\n",
" precision recall f1-score support\n",
"\n",
" 0 0.999692 0.978238 0.988849 199020\n",
" 1 0.061538 0.825581 0.114539 344\n",
"\n",
" accuracy 0.977975 199364\n",
" macro avg 0.530615 0.901910 0.551694 199364\n",
"weighted avg 0.998073 0.977975 0.987340 199364\n",
"\n",
"=========== Naive Bayes - Test 85,443 samples =============\n",
" precision recall f1-score support\n",
"\n",
" 0 0.999712 0.977994 0.988734 85295\n",
" 1 0.061969 0.837838 0.115403 148\n",
"\n",
" accuracy 0.977751 85443\n",
" macro avg 0.530841 0.907916 0.552068 85443\n",
"weighted avg 0.998088 0.977751 0.987221 85443\n",
"\n",
"Confusion Matrix in Train\n",
"[[194689 4331]\n",
" [ 60 284]]\n",
"Confusion Matrix in Test\n",
"[[83418 1877]\n",
" [ 24 124]]\n",
"************************** Stree (SVM Tree) **********************\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/rmontanana/.virtualenvs/general/lib/python3.8/site-packages/sklearn/svm/_base.py:976: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" warnings.warn(\"Liblinear failed to converge, increase \"\n",
"/Users/rmontanana/.virtualenvs/general/lib/python3.8/site-packages/sklearn/svm/_base.py:976: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" warnings.warn(\"Liblinear failed to converge, increase \"\n",
"/Users/rmontanana/.virtualenvs/general/lib/python3.8/site-packages/sklearn/svm/_base.py:976: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" warnings.warn(\"Liblinear failed to converge, increase \"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Model Stree (SVM Tree) took: 28.47 seconds\n",
"=========== Stree (SVM Tree) - Train 199,364 samples =============\n",
" precision recall f1-score support\n",
"\n",
" 0 0.999623 0.999864 0.999744 199020\n",
" 1 0.908784 0.781977 0.840625 344\n",
"\n",
" accuracy 0.999488 199364\n",
" macro avg 0.954204 0.890921 0.920184 199364\n",
"weighted avg 0.999467 0.999488 0.999469 199364\n",
"\n",
"=========== Stree (SVM Tree) - Test 85,443 samples =============\n",
" precision recall f1-score support\n",
"\n",
" 0 0.999637 0.999918 0.999777 85295\n",
" 1 0.943548 0.790541 0.860294 148\n",
"\n",
" accuracy 0.999555 85443\n",
" macro avg 0.971593 0.895229 0.930036 85443\n",
"weighted avg 0.999540 0.999555 0.999536 85443\n",
"\n",
"Confusion Matrix in Train\n",
"[[198993 27]\n",
" [ 75 269]]\n",
"Confusion Matrix in Test\n",
"[[85288 7]\n",
" [ 31 117]]\n",
"************************** Neural Network **********************\n",
"Train Model Neural Network took: 9.76 seconds\n",
"=========== Neural Network - Train 199,364 samples =============\n",
" precision recall f1-score support\n",
"\n",
" 0 0.999247 0.999844 0.999545 199020\n",
" 1 0.862222 0.563953 0.681898 344\n",
"\n",
" accuracy 0.999092 199364\n",
" macro avg 0.930734 0.781899 0.840722 199364\n",
"weighted avg 0.999010 0.999092 0.998997 199364\n",
"\n",
"=========== Neural Network - Test 85,443 samples =============\n",
" precision recall f1-score support\n",
"\n",
" 0 0.999356 0.999871 0.999613 85295\n",
" 1 0.894231 0.628378 0.738095 148\n",
"\n",
" accuracy 0.999228 85443\n",
" macro avg 0.946793 0.814125 0.868854 85443\n",
"weighted avg 0.999173 0.999228 0.999160 85443\n",
"\n",
"Confusion Matrix in Train\n",
"[[198989 31]\n",
" [ 150 194]]\n",
"Confusion Matrix in Test\n",
"[[85284 11]\n",
" [ 55 93]]\n",
"************************** SVC (linear) **********************\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/rmontanana/.virtualenvs/general/lib/python3.8/site-packages/sklearn/svm/_base.py:976: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
" warnings.warn(\"Liblinear failed to converge, increase \"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train Model SVC (linear) took: 8.207 seconds\n",
"=========== SVC (linear) - Train 199,364 samples =============\n",
" precision recall f1-score support\n",
"\n",
" 0 0.999237 0.999859 0.999548 199020\n",
" 1 0.872727 0.558140 0.680851 344\n",
"\n",
" accuracy 0.999097 199364\n",
" macro avg 0.935982 0.778999 0.840199 199364\n",
"weighted avg 0.999018 0.999097 0.998998 199364\n",
"\n",
"=========== SVC (linear) - Test 85,443 samples =============\n",
" precision recall f1-score support\n",
"\n",
" 0 0.999344 0.999894 0.999619 85295\n",
" 1 0.910891 0.621622 0.738956 148\n",
"\n",
" accuracy 0.999239 85443\n",
" macro avg 0.955117 0.810758 0.869287 85443\n",
"weighted avg 0.999191 0.999239 0.999168 85443\n",
"\n",
"Confusion Matrix in Train\n",
"[[198992 28]\n",
" [ 152 192]]\n",
"Confusion Matrix in Test\n",
"[[85286 9]\n",
" [ 56 92]]\n"
]
}
],
"source": [ "source": [
"# Train & Test models\n", "# Train & Test models\n",
"models = {\n", "models = {\n",
" 'Linear Tree':linear_tree, 'Naive Bayes': naive_bayes, 'Stree (SVM Tree)': stree, \n", " 'Linear Tree':linear_tree, 'Naive Bayes': naive_bayes, 'Stree ': stree, \n",
" 'Neural Network': mlp, 'SVC (linear)': svc\n", " 'Neural Network': mlp, 'SVC (linear)': svc\n",
"}\n", "}\n",
"\n", "\n",
@@ -464,26 +260,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 17, "execution_count": null,
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"**************************************************************************************************************\n",
"*The best f1 model is Stree (SVM Tree), with a f1 score: 0.8603 in 28.4743 seconds with 0.7 samples in train dataset\n",
"**************************************************************************************************************\n",
"Model: Linear Tree\t Time: 10.25 seconds\t f1: 0.7645\n",
"Model: Naive Bayes\t Time: 0.10 seconds\t f1: 0.1154\n",
"Model: Stree (SVM Tree)\t Time: 28.47 seconds\t f1: 0.8603\n",
"Model: Neural Network\t Time: 9.76 seconds\t f1: 0.7381\n",
"Model: SVC (linear)\t Time: 8.21 seconds\t f1: 0.739\n"
]
}
],
"source": [ "source": [
"print(\"*\"*110)\n", "print(\"*\"*110)\n",
"print(f\"*The best f1 model is {best_model}, with a f1 score: {best_f1:.4} in {best_time:.6} seconds with {train_size:,} samples in train dataset\")\n", "print(f\"*The best f1 model is {best_model}, with a f1 score: {best_f1:.4} in {best_time:.6} seconds with {train_size:,} samples in train dataset\")\n",
@@ -508,32 +289,9 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 18, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{
"data": {
"text/plain": [
"{'C': 0.01,\n",
" 'criterion': 'entropy',\n",
" 'degree': 3,\n",
" 'gamma': 'scale',\n",
" 'kernel': 'linear',\n",
" 'max_depth': None,\n",
" 'max_features': None,\n",
" 'max_iter': 1000.0,\n",
" 'min_samples_split': 0,\n",
" 'random_state': 2020,\n",
" 'split_criteria': 'impurity',\n",
" 'splitter': 'random',\n",
" 'tol': 0.0001}"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [ "source": [
"stree.get_params()" "stree.get_params()"
] ]
@@ -556,7 +314,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.8.2" "version": "3.8.2-final"
}, },
"toc": { "toc": {
"base_numbering": 1, "base_numbering": 1,

View File

@@ -17,38 +17,43 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"#\n", "#\n",
"# Google Colab setup\n", "# Google Colab setup\n",
"#\n", "#\n",
"#!pip install git+https://github.com/doctorado-ml/stree" "#!pip install git+https://github.com/doctorado-ml/stree\n",
"!pip install pandas"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import time\n", "import time\n",
"import os\n",
"import random\n",
"import warnings\n", "import warnings\n",
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn.ensemble import AdaBoostClassifier, BaggingClassifier\n", "from sklearn.ensemble import AdaBoostClassifier, BaggingClassifier\n",
"from sklearn.model_selection import train_test_split\n", "from sklearn.model_selection import train_test_split\n",
"from sklearn.exceptions import ConvergenceWarning\n", "from sklearn.exceptions import ConvergenceWarning\n",
"from stree import Stree\n", "from stree import Stree\n",
"\n",
"warnings.filterwarnings(\"ignore\", category=ConvergenceWarning)" "warnings.filterwarnings(\"ignore\", category=ConvergenceWarning)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import os\n",
"if not os.path.isfile('data/creditcard.csv'):\n", "if not os.path.isfile('data/creditcard.csv'):\n",
" !wget --no-check-certificate --content-disposition http://nube.jccm.es/index.php/s/Zs7SYtZQJ3RQ2H2/download\n", " !wget --no-check-certificate --content-disposition http://nube.jccm.es/index.php/s/Zs7SYtZQJ3RQ2H2/download\n",
" !tar xzf creditcard.tgz" " !tar xzf creditcard.tgz"
@@ -56,30 +61,15 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": null,
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fraud: 0.173% 492\n",
"Valid: 99.827% 284315\n",
"X.shape (100492, 28) y.shape (100492,)\n",
"Fraud: 0.651% 654\n",
"Valid: 99.349% 99838\n"
]
}
],
"source": [ "source": [
"random_state=1\n", "random_state=1\n",
"\n", "\n",
"def load_creditcard(n_examples=0):\n", "def load_creditcard(n_examples=0):\n",
" import pandas as pd\n",
" import numpy as np\n",
" import random\n",
" df = pd.read_csv('data/creditcard.csv')\n", " df = pd.read_csv('data/creditcard.csv')\n",
" print(\"Fraud: {0:.3f}% {1}\".format(df.Class[df.Class == 1].count()*100/df.shape[0], df.Class[df.Class == 1].count()))\n", " print(\"Fraud: {0:.3f}% {1}\".format(df.Class[df.Class == 1].count()*100/df.shape[0], df.Class[df.Class == 1].count()))\n",
" print(\"Valid: {0:.3f}% {1}\".format(df.Class[df.Class == 0].count()*100/df.shape[0], df.Class[df.Class == 0].count()))\n", " print(\"Valid: {0:.3f}% {1}\".format(df.Class[df.Class == 0].count()*100/df.shape[0], df.Class[df.Class == 0].count()))\n",
@@ -130,21 +120,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": null,
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"Score Train: 0.9984504719663368\n",
"Score Test: 0.9983415151917209\n",
"Took 26.09 seconds\n"
]
}
],
"source": [ "source": [
"now = time.time()\n", "now = time.time()\n",
"clf = Stree(max_depth=3, random_state=random_state, max_iter=1e3)\n", "clf = Stree(max_depth=3, random_state=random_state, max_iter=1e3)\n",
@@ -163,7 +143,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -174,21 +154,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": null,
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"Kernel: linear\tTime: 43.49 seconds\tScore Train: 0.9980098\tScore Test: 0.9980762\n",
"Kernel: rbf\tTime: 8.86 seconds\tScore Train: 0.9934891\tScore Test: 0.9934987\n",
"Kernel: poly\tTime: 41.14 seconds\tScore Train: 0.9972279\tScore Test: 0.9973133\n"
]
}
],
"source": [ "source": [
"for kernel in ['linear', 'rbf', 'poly']:\n", "for kernel in ['linear', 'rbf', 'poly']:\n",
" now = time.time()\n", " now = time.time()\n",
@@ -208,7 +178,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -219,21 +189,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": null,
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"Kernel: linear\tTime: 187.51 seconds\tScore Train: 0.9984505\tScore Test: 0.9983083\n",
"Kernel: rbf\tTime: 73.65 seconds\tScore Train: 0.9993461\tScore Test: 0.9985074\n",
"Kernel: poly\tTime: 52.19 seconds\tScore Train: 0.9993461\tScore Test: 0.9987727\n"
]
}
],
"source": [ "source": [
"for kernel in ['linear', 'rbf', 'poly']:\n", "for kernel in ['linear', 'rbf', 'poly']:\n",
" now = time.time()\n", " now = time.time()\n",
@@ -261,7 +221,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.8.2" "version": "3.8.2-final"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@@ -17,24 +17,27 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"#\n", "#\n",
"# Google Colab setup\n", "# Google Colab setup\n",
"#\n", "#\n",
"#!pip install git+https://github.com/doctorado-ml/stree" "#!pip install git+https://github.com/doctorado-ml/stree\n",
"!pip install pandas"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import time\n", "import time\n",
"import random\n",
"import warnings\n", "import warnings\n",
"import os\n",
"import numpy as np\n", "import numpy as np\n",
"import pandas as pd\n", "import pandas as pd\n",
"from sklearn.svm import SVC\n", "from sklearn.svm import SVC\n",
@@ -42,6 +45,7 @@
"from sklearn.utils.estimator_checks import check_estimator\n", "from sklearn.utils.estimator_checks import check_estimator\n",
"from sklearn.datasets import make_classification, load_iris, load_wine\n", "from sklearn.datasets import make_classification, load_iris, load_wine\n",
"from sklearn.model_selection import train_test_split\n", "from sklearn.model_selection import train_test_split\n",
"from sklearn.utils.class_weight import compute_sample_weight\n",
"from sklearn.exceptions import ConvergenceWarning\n", "from sklearn.exceptions import ConvergenceWarning\n",
"from stree import Stree\n", "from stree import Stree\n",
"warnings.filterwarnings(\"ignore\", category=ConvergenceWarning)" "warnings.filterwarnings(\"ignore\", category=ConvergenceWarning)"
@@ -49,13 +53,12 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": null,
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"import os\n",
"if not os.path.isfile('data/creditcard.csv'):\n", "if not os.path.isfile('data/creditcard.csv'):\n",
" !wget --no-check-certificate --content-disposition http://nube.jccm.es/index.php/s/Zs7SYtZQJ3RQ2H2/download\n", " !wget --no-check-certificate --content-disposition http://nube.jccm.es/index.php/s/Zs7SYtZQJ3RQ2H2/download\n",
" !tar xzf creditcard.tgz" " !tar xzf creditcard.tgz"
@@ -63,31 +66,15 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": null,
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fraud: 0.173% 492\n",
"Valid: 99.827% 284315\n",
"X.shape (5492, 28) y.shape (5492,)\n",
"Fraud: 9.086% 499\n",
"Valid: 90.914% 4993\n",
"[0.09079084 0.09079084 0.09079084 0.09079084] [0.09101942 0.09101942 0.09101942 0.09101942]\n"
]
}
],
"source": [ "source": [
"random_state=1\n", "random_state=1\n",
"\n", "\n",
"def load_creditcard(n_examples=0):\n", "def load_creditcard(n_examples=0):\n",
" import pandas as pd\n",
" import numpy as np\n",
" import random\n",
" df = pd.read_csv('data/creditcard.csv')\n", " df = pd.read_csv('data/creditcard.csv')\n",
" print(\"Fraud: {0:.3f}% {1}\".format(df.Class[df.Class == 1].count()*100/df.shape[0], df.Class[df.Class == 1].count()))\n", " print(\"Fraud: {0:.3f}% {1}\".format(df.Class[df.Class == 1].count()*100/df.shape[0], df.Class[df.Class == 1].count()))\n",
" print(\"Valid: {0:.3f}% {1}\".format(df.Class[df.Class == 0].count()*100/df.shape[0], df.Class[df.Class == 0].count()))\n", " print(\"Valid: {0:.3f}% {1}\".format(df.Class[df.Class == 0].count()*100/df.shape[0], df.Class[df.Class == 0].count()))\n",
@@ -119,17 +106,8 @@
"Xtest = data[1]\n", "Xtest = data[1]\n",
"ytrain = data[2]\n", "ytrain = data[2]\n",
"ytest = data[3]\n", "ytest = data[3]\n",
"_, data = np.unique(ytrain, return_counts=True)\n", "weights = compute_sample_weight(\"balanced\", ytrain)\n",
"wtrain = (data[1] / np.sum(data), data[0] / np.sum(data))\n", "weights_test = compute_sample_weight(\"balanced\", ytest)\n",
"_, data = np.unique(ytest, return_counts=True)\n",
"wtest = (data[1] / np.sum(data), data[0] / np.sum(data))\n",
"# Set weights inverse to its count class in dataset\n",
"weights = np.ones(Xtrain.shape[0],)\n",
"weights[ytrain==0] = wtrain[0]\n",
"weights[ytrain==1] = wtrain[1]\n",
"weights_test = np.ones(Xtest.shape[0],)\n",
"weights_test[ytest==0] = wtest[0]\n",
"weights_test[ytest==1] = wtest[1]\n",
"print(weights[:4], weights_test[:4])" "print(weights[:4], weights_test[:4])"
] ]
}, },
@@ -150,22 +128,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": null,
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy of Train without weights 0.9849115504682622\n",
"Accuracy of Train with weights 0.9849115504682622\n",
"Accuracy of Tests without weights 0.9848300970873787\n",
"Accuracy of Tests with weights 0.9805825242718447\n"
]
}
],
"source": [ "source": [
"C = 23\n", "C = 23\n",
"print(\"Accuracy of Train without weights\", Stree(C=C, random_state=1).fit(Xtrain, ytrain).score(Xtrain, ytrain))\n", "print(\"Accuracy of Train without weights\", Stree(C=C, random_state=1).fit(Xtrain, ytrain).score(Xtrain, ytrain))\n",
@@ -184,21 +151,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": null,
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"Time: 26.59s\tKernel: linear\tAccuracy_train: 0.9846514047866806\tAccuracy_test: 0.9848300970873787\n",
"Time: 0.56s\tKernel: rbf\tAccuracy_train: 0.9947970863683663\tAccuracy_test: 0.9866504854368932\n",
"Time: 0.23s\tKernel: poly\tAccuracy_train: 0.9955775234131113\tAccuracy_test: 0.9824029126213593\n"
]
}
],
"source": [ "source": [
"random_state=1\n", "random_state=1\n",
"for kernel in ['linear', 'rbf', 'poly']:\n", "for kernel in ['linear', 'rbf', 'poly']:\n",
@@ -219,77 +176,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": null,
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"************** C=0.001 ****************************\n",
"Classifier's accuracy (train): 0.9823\n",
"Classifier's accuracy (test) : 0.9836\n",
"root feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.4391 counts=(array([0, 1]), array([3495, 349]))\n",
"root - Down, <cgaf> - Leaf class=0 belief= 0.981455 impurity=0.1332 counts=(array([0, 1]), array([3493, 66]))\n",
"root - Up, <cgaf> - Leaf class=1 belief= 0.992982 impurity=0.0603 counts=(array([0, 1]), array([ 2, 283]))\n",
"\n",
"**************************************************\n",
"************** C=0.01 ****************************\n",
"Classifier's accuracy (train): 0.9834\n",
"Classifier's accuracy (test) : 0.9842\n",
"root feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.4391 counts=(array([0, 1]), array([3495, 349]))\n",
"root - Down, <cgaf> - Leaf class=0 belief= 0.982288 impurity=0.1284 counts=(array([0, 1]), array([3494, 63]))\n",
"root - Up, <cgaf> - Leaf class=1 belief= 0.996516 impurity=0.0335 counts=(array([0, 1]), array([ 1, 286]))\n",
"\n",
"**************************************************\n",
"************** C=1 ****************************\n",
"Classifier's accuracy (train): 0.9844\n",
"Classifier's accuracy (test) : 0.9848\n",
"root feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.4391 counts=(array([0, 1]), array([3495, 349]))\n",
"root - Down feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.1236 counts=(array([0, 1]), array([3493, 60]))\n",
"root - Down - Down, <cgaf> - Leaf class=0 belief= 0.983108 impurity=0.1236 counts=(array([0, 1]), array([3492, 60]))\n",
"root - Down - Up, <pure> - Leaf class=0 belief= 1.000000 impurity=0.0000 counts=(array([0]), array([1]))\n",
"root - Up feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.0593 counts=(array([0, 1]), array([ 2, 289]))\n",
"root - Up - Down, <pure> - Leaf class=0 belief= 1.000000 impurity=0.0000 counts=(array([0]), array([2]))\n",
"root - Up - Up, <pure> - Leaf class=1 belief= 1.000000 impurity=0.0000 counts=(array([1]), array([289]))\n",
"\n",
"**************************************************\n",
"************** C=5 ****************************\n",
"Classifier's accuracy (train): 0.9847\n",
"Classifier's accuracy (test) : 0.9848\n",
"root feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.4391 counts=(array([0, 1]), array([3495, 349]))\n",
"root - Down feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.1236 counts=(array([0, 1]), array([3493, 60]))\n",
"root - Down - Down feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.1236 counts=(array([0, 1]), array([3492, 60]))\n",
"root - Down - Down - Down, <cgaf> - Leaf class=0 belief= 0.983385 impurity=0.1220 counts=(array([0, 1]), array([3492, 59]))\n",
"root - Down - Down - Up, <pure> - Leaf class=1 belief= 1.000000 impurity=0.0000 counts=(array([1]), array([1]))\n",
"root - Down - Up, <pure> - Leaf class=0 belief= 1.000000 impurity=0.0000 counts=(array([0]), array([1]))\n",
"root - Up feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.0593 counts=(array([0, 1]), array([ 2, 289]))\n",
"root - Up - Down, <pure> - Leaf class=0 belief= 1.000000 impurity=0.0000 counts=(array([0]), array([2]))\n",
"root - Up - Up, <pure> - Leaf class=1 belief= 1.000000 impurity=0.0000 counts=(array([1]), array([289]))\n",
"\n",
"**************************************************\n",
"************** C=17 ****************************\n",
"Classifier's accuracy (train): 0.9847\n",
"Classifier's accuracy (test) : 0.9848\n",
"root feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.4391 counts=(array([0, 1]), array([3495, 349]))\n",
"root - Down feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.1236 counts=(array([0, 1]), array([3493, 60]))\n",
"root - Down - Down feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.1220 counts=(array([0, 1]), array([3492, 59]))\n",
"root - Down - Down - Down, <cgaf> - Leaf class=0 belief= 0.983380 impurity=0.1220 counts=(array([0, 1]), array([3491, 59]))\n",
"root - Down - Down - Up, <pure> - Leaf class=0 belief= 1.000000 impurity=0.0000 counts=(array([0]), array([1]))\n",
"root - Down - Up feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=1.0000 counts=(array([0, 1]), array([1, 1]))\n",
"root - Down - Up - Down, <pure> - Leaf class=0 belief= 1.000000 impurity=0.0000 counts=(array([0]), array([1]))\n",
"root - Down - Up - Up, <pure> - Leaf class=1 belief= 1.000000 impurity=0.0000 counts=(array([1]), array([1]))\n",
"root - Up feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.0593 counts=(array([0, 1]), array([ 2, 289]))\n",
"root - Up - Down, <pure> - Leaf class=0 belief= 1.000000 impurity=0.0000 counts=(array([0]), array([2]))\n",
"root - Up - Up, <pure> - Leaf class=1 belief= 1.000000 impurity=0.0000 counts=(array([1]), array([289]))\n",
"\n",
"**************************************************\n",
"59.0161 secs\n"
]
}
],
"source": [ "source": [
"t = time.time()\n", "t = time.time()\n",
"for C in (.001, .01, 1, 5, 17):\n", "for C in (.001, .01, 1, 5, 17):\n",
@@ -313,29 +204,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": null,
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"root feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.4391 counts=(array([0, 1]), array([3495, 349]))\n",
"root - Down feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.1236 counts=(array([0, 1]), array([3493, 60]))\n",
"root - Down - Down feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.1220 counts=(array([0, 1]), array([3492, 59]))\n",
"root - Down - Down - Down, <cgaf> - Leaf class=0 belief= 0.983380 impurity=0.1220 counts=(array([0, 1]), array([3491, 59]))\n",
"root - Down - Down - Up, <pure> - Leaf class=0 belief= 1.000000 impurity=0.0000 counts=(array([0]), array([1]))\n",
"root - Down - Up feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=1.0000 counts=(array([0, 1]), array([1, 1]))\n",
"root - Down - Up - Down, <pure> - Leaf class=0 belief= 1.000000 impurity=0.0000 counts=(array([0]), array([1]))\n",
"root - Down - Up - Up, <pure> - Leaf class=1 belief= 1.000000 impurity=0.0000 counts=(array([1]), array([1]))\n",
"root - Up feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.0593 counts=(array([0, 1]), array([ 2, 289]))\n",
"root - Up - Down, <pure> - Leaf class=0 belief= 1.000000 impurity=0.0000 counts=(array([0]), array([2]))\n",
"root - Up - Up, <pure> - Leaf class=1 belief= 1.000000 impurity=0.0000 counts=(array([1]), array([289]))\n"
]
}
],
"source": [ "source": [
"#check iterator\n", "#check iterator\n",
"for i in list(clf):\n", "for i in list(clf):\n",
@@ -344,29 +217,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": null,
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"root feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.4391 counts=(array([0, 1]), array([3495, 349]))\n",
"root - Down feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.1236 counts=(array([0, 1]), array([3493, 60]))\n",
"root - Down - Down feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.1220 counts=(array([0, 1]), array([3492, 59]))\n",
"root - Down - Down - Down, <cgaf> - Leaf class=0 belief= 0.983380 impurity=0.1220 counts=(array([0, 1]), array([3491, 59]))\n",
"root - Down - Down - Up, <pure> - Leaf class=0 belief= 1.000000 impurity=0.0000 counts=(array([0]), array([1]))\n",
"root - Down - Up feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=1.0000 counts=(array([0, 1]), array([1, 1]))\n",
"root - Down - Up - Down, <pure> - Leaf class=0 belief= 1.000000 impurity=0.0000 counts=(array([0]), array([1]))\n",
"root - Down - Up - Up, <pure> - Leaf class=1 belief= 1.000000 impurity=0.0000 counts=(array([1]), array([1]))\n",
"root - Up feaures=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27) impurity=0.0593 counts=(array([0, 1]), array([ 2, 289]))\n",
"root - Up - Down, <pure> - Leaf class=0 belief= 1.000000 impurity=0.0000 counts=(array([0]), array([2]))\n",
"root - Up - Up, <pure> - Leaf class=1 belief= 1.000000 impurity=0.0000 counts=(array([1]), array([289]))\n"
]
}
],
"source": [ "source": [
"#check iterator again\n", "#check iterator again\n",
"for i in clf:\n", "for i in clf:\n",
@@ -382,73 +237,17 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 14, "execution_count": null,
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 functools.partial(<function check_no_attributes_set_in_init at 0x16817f670>, 'Stree')\n",
"2 functools.partial(<function check_estimators_dtypes at 0x168179820>, 'Stree')\n",
"3 functools.partial(<function check_fit_score_takes_y at 0x168179700>, 'Stree')\n",
"4 functools.partial(<function check_sample_weights_pandas_series at 0x168174040>, 'Stree')\n",
"5 functools.partial(<function check_sample_weights_not_an_array at 0x168174160>, 'Stree')\n",
"6 functools.partial(<function check_sample_weights_list at 0x168174280>, 'Stree')\n",
"7 functools.partial(<function check_sample_weights_shape at 0x1681743a0>, 'Stree')\n",
"8 functools.partial(<function check_sample_weights_invariance at 0x1681744c0>, 'Stree', kind='ones')\n",
"10 functools.partial(<function check_estimators_fit_returns_self at 0x16817b8b0>, 'Stree')\n",
"11 functools.partial(<function check_estimators_fit_returns_self at 0x16817b8b0>, 'Stree', readonly_memmap=True)\n",
"12 functools.partial(<function check_complex_data at 0x168174670>, 'Stree')\n",
"13 functools.partial(<function check_dtype_object at 0x1681745e0>, 'Stree')\n",
"14 functools.partial(<function check_estimators_empty_data_messages at 0x1681799d0>, 'Stree')\n",
"15 functools.partial(<function check_pipeline_consistency at 0x1681795e0>, 'Stree')\n",
"16 functools.partial(<function check_estimators_nan_inf at 0x168179af0>, 'Stree')\n",
"17 functools.partial(<function check_estimators_overwrite_params at 0x16817f550>, 'Stree')\n",
"18 functools.partial(<function check_estimator_sparse_data at 0x168172ee0>, 'Stree')\n",
"19 functools.partial(<function check_estimators_pickle at 0x168179d30>, 'Stree')\n",
"20 functools.partial(<function check_estimator_get_tags_default_keys at 0x168181790>, 'Stree')\n",
"21 functools.partial(<function check_classifier_data_not_an_array at 0x16817f8b0>, 'Stree')\n",
"22 functools.partial(<function check_classifiers_one_label at 0x16817b430>, 'Stree')\n",
"23 functools.partial(<function check_classifiers_classes at 0x16817bd30>, 'Stree')\n",
"24 functools.partial(<function check_estimators_partial_fit_n_features at 0x168179e50>, 'Stree')\n",
"25 functools.partial(<function check_classifiers_train at 0x16817b550>, 'Stree')\n",
"26 functools.partial(<function check_classifiers_train at 0x16817b550>, 'Stree', readonly_memmap=True)\n",
"27 functools.partial(<function check_classifiers_train at 0x16817b550>, 'Stree', readonly_memmap=True, X_dtype='float32')\n",
"28 functools.partial(<function check_classifiers_regression_target at 0x168181280>, 'Stree')\n",
"29 functools.partial(<function check_supervised_y_no_nan at 0x1681720d0>, 'Stree')\n",
"30 functools.partial(<function check_supervised_y_2d at 0x16817baf0>, 'Stree')\n",
"31 functools.partial(<function check_estimators_unfitted at 0x16817b9d0>, 'Stree')\n",
"32 functools.partial(<function check_non_transformer_estimators_n_iter at 0x16817fdc0>, 'Stree')\n",
"33 functools.partial(<function check_decision_proba_consistency at 0x1681813a0>, 'Stree')\n",
"34 functools.partial(<function check_parameters_default_constructible at 0x16817fb80>, 'Stree')\n",
"35 functools.partial(<function check_methods_sample_order_invariance at 0x168174d30>, 'Stree')\n",
"36 functools.partial(<function check_methods_subset_invariance at 0x168174c10>, 'Stree')\n",
"37 functools.partial(<function check_fit2d_1sample at 0x168174e50>, 'Stree')\n",
"38 functools.partial(<function check_fit2d_1feature at 0x168174f70>, 'Stree')\n",
"39 functools.partial(<function check_get_params_invariance at 0x168181040>, 'Stree')\n",
"40 functools.partial(<function check_set_params at 0x168181160>, 'Stree')\n",
"41 functools.partial(<function check_dict_unchanged at 0x168174790>, 'Stree')\n",
"42 functools.partial(<function check_dont_overwrite_parameters at 0x168174940>, 'Stree')\n",
"43 functools.partial(<function check_fit_idempotent at 0x168181550>, 'Stree')\n",
"44 functools.partial(<function check_n_features_in at 0x1681815e0>, 'Stree')\n",
"45 functools.partial(<function check_fit1d at 0x1681790d0>, 'Stree')\n",
"46 functools.partial(<function check_fit2d_predict1d at 0x168174a60>, 'Stree')\n",
"47 functools.partial(<function check_requires_y_none at 0x168181670>, 'Stree')\n"
]
}
],
"source": [ "source": [
"# Make checks one by one\n", "# Make checks one by one\n",
"c = 0\n", "c = 0\n",
"checks = check_estimator(Stree(), generate_only=True)\n", "checks = check_estimator(Stree(), generate_only=True)\n",
"for check in checks:\n", "for check in checks:\n",
" c += 1\n", " c += 1\n",
" if c == 9:\n",
" pass\n",
" else:\n",
" print(c, check[1])\n", " print(c, check[1])\n",
" check[1](check[0])" " check[1](check[0])"
] ]
@@ -552,7 +351,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.9.1" "version": "3.8.2-final"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@@ -18,19 +18,20 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"#\n", "#\n",
"# Google Colab setup\n", "# Google Colab setup\n",
"#\n", "#\n",
"#!pip install git+https://github.com/doctorado-ml/stree" "#!pip install git+https://github.com/doctorado-ml/stree\n",
"!pip install pandas"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
@@ -38,6 +39,10 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"import random\n",
"import os\n",
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn.ensemble import AdaBoostClassifier\n", "from sklearn.ensemble import AdaBoostClassifier\n",
"from sklearn.svm import LinearSVC\n", "from sklearn.svm import LinearSVC\n",
"from sklearn.model_selection import GridSearchCV, train_test_split\n", "from sklearn.model_selection import GridSearchCV, train_test_split\n",
@@ -46,7 +51,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
@@ -54,7 +59,6 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"import os\n",
"if not os.path.isfile('data/creditcard.csv'):\n", "if not os.path.isfile('data/creditcard.csv'):\n",
" !wget --no-check-certificate --content-disposition http://nube.jccm.es/index.php/s/Zs7SYtZQJ3RQ2H2/download\n", " !wget --no-check-certificate --content-disposition http://nube.jccm.es/index.php/s/Zs7SYtZQJ3RQ2H2/download\n",
" !tar xzf creditcard.tgz" " !tar xzf creditcard.tgz"
@@ -62,7 +66,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
@@ -70,26 +74,11 @@
"outputId": "afc822fb-f16a-4302-8a67-2b9e2880159b", "outputId": "afc822fb-f16a-4302-8a67-2b9e2880159b",
"tags": [] "tags": []
}, },
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fraud: 0.173% 492\n",
"Valid: 99.827% 284315\n",
"X.shape (1492, 28) y.shape (1492,)\n",
"Fraud: 33.177% 495\n",
"Valid: 66.823% 997\n"
]
}
],
"source": [ "source": [
"random_state=1\n", "random_state=1\n",
"\n", "\n",
"def load_creditcard(n_examples=0):\n", "def load_creditcard(n_examples=0):\n",
" import pandas as pd\n",
" import numpy as np\n",
" import random\n",
" df = pd.read_csv('data/creditcard.csv')\n", " df = pd.read_csv('data/creditcard.csv')\n",
" print(\"Fraud: {0:.3f}% {1}\".format(df.Class[df.Class == 1].count()*100/df.shape[0], df.Class[df.Class == 1].count()))\n", " print(\"Fraud: {0:.3f}% {1}\".format(df.Class[df.Class == 1].count()*100/df.shape[0], df.Class[df.Class == 1].count()))\n",
" print(\"Valid: {0:.3f}% {1}\".format(df.Class[df.Class == 0].count()*100/df.shape[0], df.Class[df.Class == 0].count()))\n", " print(\"Valid: {0:.3f}% {1}\".format(df.Class[df.Class == 0].count()*100/df.shape[0], df.Class[df.Class == 0].count()))\n",
@@ -132,7 +121,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
@@ -176,39 +165,16 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{
"data": {
"text/plain": [
"{'C': 1.0,\n",
" 'criterion': 'entropy',\n",
" 'degree': 3,\n",
" 'gamma': 'scale',\n",
" 'kernel': 'linear',\n",
" 'max_depth': None,\n",
" 'max_features': None,\n",
" 'max_iter': 100000.0,\n",
" 'min_samples_split': 0,\n",
" 'random_state': None,\n",
" 'split_criteria': 'impurity',\n",
" 'splitter': 'random',\n",
" 'tol': 0.0001}"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [ "source": [
"Stree().get_params()" "Stree().get_params()"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
@@ -216,69 +182,7 @@
"outputId": "7703413a-d563-4289-a13b-532f38f82762", "outputId": "7703413a-d563-4289-a13b-532f38f82762",
"tags": [] "tags": []
}, },
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fitting 5 folds for each of 1008 candidates, totalling 5040 fits\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[Parallel(n_jobs=-1)]: Using backend LokyBackend with 16 concurrent workers.\n",
"[Parallel(n_jobs=-1)]: Done 40 tasks | elapsed: 1.6s\n",
"[Parallel(n_jobs=-1)]: Done 130 tasks | elapsed: 3.1s\n",
"[Parallel(n_jobs=-1)]: Done 256 tasks | elapsed: 5.5s\n",
"[Parallel(n_jobs=-1)]: Done 418 tasks | elapsed: 9.3s\n",
"[Parallel(n_jobs=-1)]: Done 616 tasks | elapsed: 18.6s\n",
"[Parallel(n_jobs=-1)]: Done 850 tasks | elapsed: 28.2s\n",
"[Parallel(n_jobs=-1)]: Done 1120 tasks | elapsed: 35.4s\n",
"[Parallel(n_jobs=-1)]: Done 1426 tasks | elapsed: 43.5s\n",
"[Parallel(n_jobs=-1)]: Done 1768 tasks | elapsed: 51.3s\n",
"[Parallel(n_jobs=-1)]: Done 2146 tasks | elapsed: 1.0min\n",
"[Parallel(n_jobs=-1)]: Done 2560 tasks | elapsed: 1.2min\n",
"[Parallel(n_jobs=-1)]: Done 3010 tasks | elapsed: 1.4min\n",
"[Parallel(n_jobs=-1)]: Done 3496 tasks | elapsed: 1.7min\n",
"[Parallel(n_jobs=-1)]: Done 4018 tasks | elapsed: 2.1min\n",
"[Parallel(n_jobs=-1)]: Done 4576 tasks | elapsed: 2.6min\n",
"[Parallel(n_jobs=-1)]: Done 5040 out of 5040 | elapsed: 2.9min finished\n"
]
},
{
"data": {
"text/plain": [
"GridSearchCV(estimator=AdaBoostClassifier(algorithm='SAMME', random_state=1),\n",
" n_jobs=-1,\n",
" param_grid=[{'base_estimator': [Stree(C=55, max_depth=7,\n",
" random_state=1,\n",
" split_criteria='max_samples',\n",
" tol=0.1)],\n",
" 'base_estimator__C': [1, 7, 55],\n",
" 'base_estimator__kernel': ['linear'],\n",
" 'base_estimator__max_depth': [3, 5, 7],\n",
" 'base_estimator__split_criteria': ['max_samples',\n",
" 'impuri...\n",
" {'base_estimator': [Stree(random_state=1)],\n",
" 'base_estimator__C': [1, 7, 55],\n",
" 'base_estimator__gamma': [0.1, 1, 10],\n",
" 'base_estimator__kernel': ['rbf'],\n",
" 'base_estimator__max_depth': [3, 5, 7],\n",
" 'base_estimator__split_criteria': ['max_samples',\n",
" 'impurity'],\n",
" 'base_estimator__tol': [0.1, 0.01],\n",
" 'learning_rate': [0.5, 1],\n",
" 'n_estimators': [10, 25]}],\n",
" return_train_score=True, verbose=5)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [ "source": [
"clf = AdaBoostClassifier(random_state=random_state, algorithm=\"SAMME\")\n", "clf = AdaBoostClassifier(random_state=random_state, algorithm=\"SAMME\")\n",
"grid = GridSearchCV(clf, parameters, verbose=5, n_jobs=-1, return_train_score=True)\n", "grid = GridSearchCV(clf, parameters, verbose=5, n_jobs=-1, return_train_score=True)\n",
@@ -287,7 +191,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": null,
"metadata": { "metadata": {
"colab": {}, "colab": {},
"colab_type": "code", "colab_type": "code",
@@ -295,20 +199,7 @@
"outputId": "285163c8-fa33-4915-8ae7-61c4f7844344", "outputId": "285163c8-fa33-4915-8ae7-61c4f7844344",
"tags": [] "tags": []
}, },
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best estimator: AdaBoostClassifier(algorithm='SAMME',\n",
" base_estimator=Stree(C=55, max_depth=7, random_state=1,\n",
" split_criteria='max_samples', tol=0.1),\n",
" learning_rate=0.5, n_estimators=25, random_state=1)\n",
"Best hyperparameters: {'base_estimator': Stree(C=55, max_depth=7, random_state=1, split_criteria='max_samples', tol=0.1), 'base_estimator__C': 55, 'base_estimator__kernel': 'linear', 'base_estimator__max_depth': 7, 'base_estimator__split_criteria': 'max_samples', 'base_estimator__tol': 0.1, 'learning_rate': 0.5, 'n_estimators': 25}\n",
"Best accuracy: 0.9511777695988222\n"
]
}
],
"source": [ "source": [
"print(\"Best estimator: \", grid.best_estimator_)\n", "print(\"Best estimator: \", grid.best_estimator_)\n",
"print(\"Best hyperparameters: \", grid.best_params_)\n", "print(\"Best hyperparameters: \", grid.best_params_)\n",
@@ -354,7 +245,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.8.2" "version": "3.8.2-final"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@@ -1,4 +1 @@
numpy scikit-learn>0.24
scikit-learn
pandas
ipympl

View File

@@ -15,6 +15,7 @@ from typing import Optional
import numpy as np import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.svm import SVC, LinearSVC from sklearn.svm import SVC, LinearSVC
from sklearn.preprocessing import StandardScaler
from sklearn.utils import check_consistent_length from sklearn.utils import check_consistent_length
from sklearn.utils.multiclass import check_classification_targets from sklearn.utils.multiclass import check_classification_targets
from sklearn.exceptions import ConvergenceWarning from sklearn.exceptions import ConvergenceWarning
@@ -41,6 +42,7 @@ class Snode:
impurity: float, impurity: float,
title: str, title: str,
weight: np.ndarray = None, weight: np.ndarray = None,
scaler: StandardScaler = None,
): ):
self._clf = clf self._clf = clf
self._title = title self._title = title
@@ -58,6 +60,7 @@ class Snode:
self._features = features self._features = features
self._impurity = impurity self._impurity = impurity
self._partition_column: int = -1 self._partition_column: int = -1
self._scaler = scaler
@classmethod @classmethod
def copy(cls, node: "Snode") -> "Snode": def copy(cls, node: "Snode") -> "Snode":
@@ -68,6 +71,8 @@ class Snode:
node._features, node._features,
node._impurity, node._impurity,
node._title, node._title,
node._sample_weight,
node._scaler,
) )
def set_partition_column(self, col: int): def set_partition_column(self, col: int):
@@ -79,6 +84,30 @@ class Snode:
def set_down(self, son): def set_down(self, son):
self._down = son self._down = son
def set_title(self, title):
self._title = title
def set_classifier(self, clf):
self._clf = clf
def set_features(self, features):
self._features = features
def set_impurity(self, impurity):
self._impurity = impurity
def get_title(self) -> str:
return self._title
def get_classifier(self) -> SVC:
return self._clf
def get_impurity(self) -> float:
return self._impurity
def get_features(self) -> np.array:
return self._features
def set_up(self, son): def set_up(self, son):
self._up = son self._up = son
@@ -154,6 +183,7 @@ class Splitter:
criteria: str = None, criteria: str = None,
min_samples_split: int = None, min_samples_split: int = None,
random_state=None, random_state=None,
normalize=False,
): ):
self._clf = clf self._clf = clf
self._random_state = random_state self._random_state = random_state
@@ -163,6 +193,7 @@ class Splitter:
self._min_samples_split = min_samples_split self._min_samples_split = min_samples_split
self._criteria = criteria self._criteria = criteria
self._splitter_type = splitter_type self._splitter_type = splitter_type
self._normalize = normalize
if clf is None: if clf is None:
raise ValueError(f"clf has to be a sklearn estimator, got({clf})") raise ValueError(f"clf has to be a sklearn estimator, got({clf})")
@@ -462,8 +493,7 @@ class Splitter:
origin[down] if any(down) else None, origin[down] if any(down) else None,
] ]
@staticmethod def _distances(self, node: Snode, data: np.ndarray) -> np.array:
def _distances(node: Snode, data: np.ndarray) -> np.array:
"""Compute distances of the samples to the hyperplane of the node """Compute distances of the samples to the hyperplane of the node
Parameters Parameters
@@ -479,7 +509,10 @@ class Splitter:
array of shape (m, nc) with the distances of every sample to array of shape (m, nc) with the distances of every sample to
the hyperplane of every class. nc = # of classes the hyperplane of every class. nc = # of classes
""" """
return node._clf.decision_function(data[:, node._features]) X_transformed = data[:, node._features]
if self._normalize:
X_transformed = node._scaler.transform(X_transformed)
return node._clf.decision_function(X_transformed)
class Stree(BaseEstimator, ClassifierMixin): class Stree(BaseEstimator, ClassifierMixin):
@@ -505,6 +538,7 @@ class Stree(BaseEstimator, ClassifierMixin):
min_samples_split: int = 0, min_samples_split: int = 0,
max_features=None, max_features=None,
splitter: str = "random", splitter: str = "random",
normalize: bool = False,
): ):
self.max_iter = max_iter self.max_iter = max_iter
self.C = C self.C = C
@@ -519,6 +553,7 @@ class Stree(BaseEstimator, ClassifierMixin):
self.max_features = max_features self.max_features = max_features
self.criterion = criterion self.criterion = criterion
self.splitter = splitter self.splitter = splitter
self.normalize = normalize
def _more_tags(self) -> dict: def _more_tags(self) -> dict:
"""Required by sklearn to supply features of the classifier """Required by sklearn to supply features of the classifier
@@ -582,6 +617,7 @@ class Stree(BaseEstimator, ClassifierMixin):
criteria=self.split_criteria, criteria=self.split_criteria,
random_state=self.random_state, random_state=self.random_state,
min_samples_split=self.min_samples_split, min_samples_split=self.min_samples_split,
normalize=self.normalize,
) )
if self.random_state is not None: if self.random_state is not None:
random.seed(self.random_state) random.seed(self.random_state)
@@ -635,41 +671,39 @@ class Stree(BaseEstimator, ClassifierMixin):
X = X[~indices_zero, :] X = X[~indices_zero, :]
y = y[~indices_zero] y = y[~indices_zero]
sample_weight = sample_weight[~indices_zero] sample_weight = sample_weight[~indices_zero]
self.depth_ = max(depth, self.depth_)
scaler = StandardScaler()
node = Snode(None, X, y, X.shape[1], 0.0, title, sample_weight, scaler)
if np.unique(y).shape[0] == 1: if np.unique(y).shape[0] == 1:
# only 1 class => pure dataset # only 1 class => pure dataset
return Snode( node.set_title(title + ", <pure>")
clf=None, return node
X=X,
y=y,
features=X.shape[1],
impurity=0.0,
title=title + ", <pure>",
weight=sample_weight,
)
# Train the model # Train the model
clf = self._build_clf() clf = self._build_clf()
Xs, features = self.splitter_.get_subspace(X, y, self.max_features_) Xs, features = self.splitter_.get_subspace(X, y, self.max_features_)
if self.normalize:
scaler.fit(Xs)
Xs = scaler.transform(Xs)
clf.fit(Xs, y, sample_weight=sample_weight) clf.fit(Xs, y, sample_weight=sample_weight)
impurity = self.splitter_.partition_impurity(y) node.set_impurity(self.splitter_.partition_impurity(y))
node = Snode(clf, X, y, features, impurity, title, sample_weight) node.set_classifier(clf)
self.depth_ = max(depth, self.depth_) node.set_features(features)
self.splitter_.partition(X, node, True) self.splitter_.partition(X, node, True)
X_U, X_D = self.splitter_.part(X) X_U, X_D = self.splitter_.part(X)
y_u, y_d = self.splitter_.part(y) y_u, y_d = self.splitter_.part(y)
sw_u, sw_d = self.splitter_.part(sample_weight) sw_u, sw_d = self.splitter_.part(sample_weight)
if X_U is None or X_D is None: if X_U is None or X_D is None:
# didn't part anything # didn't part anything
return Snode( node.set_title(title + ", <cgaf>")
clf, return node
X, node.set_up(
y, self.train(X_U, y_u, sw_u, depth + 1, title + f" - Up({depth+1})")
features=X.shape[1], )
impurity=impurity, node.set_down(
title=title + ", <cgaf>", self.train(
weight=sample_weight, X_D, y_d, sw_d, depth + 1, title + f" - Down({depth+1})"
)
) )
node.set_up(self.train(X_U, y_u, sw_u, depth + 1, title + " - Up"))
node.set_down(self.train(X_D, y_d, sw_d, depth + 1, title + " - Down"))
return node return node
def _build_predictor(self): def _build_predictor(self):
@@ -812,6 +846,22 @@ class Stree(BaseEstimator, ClassifierMixin):
score = y_true == y_pred score = y_true == y_pred
return _weighted_sum(score, sample_weight, normalize=True) return _weighted_sum(score, sample_weight, normalize=True)
def nodes_leaves(self) -> tuple:
"""Compute the number of nodes and leaves in the built tree
Returns
-------
[tuple]
tuple with the number of nodes and the number of leaves
"""
nodes = 0
leaves = 0
for node in self:
nodes += 1
if node.is_leaf():
leaves += 1
return nodes, leaves
def __iter__(self) -> Siterator: def __iter__(self) -> Siterator:
"""Create an iterator to be able to visit the nodes of the tree in """Create an iterator to be able to visit the nodes of the tree in
preorder, can make a list with all the nodes in preorder preorder, can make a list with all the nodes in preorder

View File

@@ -1,8 +1,6 @@
import os import os
import unittest import unittest
import numpy as np import numpy as np
from stree import Stree, Snode from stree import Stree, Snode
from .utils import load_dataset from .utils import load_dataset
@@ -69,6 +67,31 @@ class Snode_test(unittest.TestCase):
self.assertEqual(0.75, test._belief) self.assertEqual(0.75, test._belief)
self.assertEqual(-1, test._partition_column) self.assertEqual(-1, test._partition_column)
def test_set_title(self):
test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test")
self.assertEqual("test", test.get_title())
test.set_title("another")
self.assertEqual("another", test.get_title())
def test_set_classifier(self):
test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test")
clf = Stree()
self.assertIsNone(test.get_classifier())
test.set_classifier(clf)
self.assertEqual(clf, test.get_classifier())
def test_set_impurity(self):
test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test")
self.assertEqual(0.0, test.get_impurity())
test.set_impurity(54.7)
self.assertEqual(54.7, test.get_impurity())
def test_set_features(self):
test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [0, 1], 0.0, "test")
self.assertListEqual([0, 1], test.get_features())
test.set_features([1, 2])
self.assertListEqual([1, 2], test.get_features())
def test_make_predictor_on_not_leaf(self): def test_make_predictor_on_not_leaf(self):
test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test") test = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test")
test.set_up(Snode(None, [1], [1], [], 0.0, "another_test")) test.set_up(Snode(None, [1], [1], [], 0.0, "another_test"))
@@ -94,3 +117,5 @@ class Snode_test(unittest.TestCase):
self.assertEqual("test", computed._title) self.assertEqual("test", computed._title)
self.assertIsInstance(computed._clf, Stree) self.assertIsInstance(computed._clf, Stree)
self.assertEqual(test._partition_column, computed._partition_column) self.assertEqual(test._partition_column, computed._partition_column)
self.assertEqual(test._sample_weight, computed._sample_weight)
self.assertEqual(test._scaler, computed._scaler)

View File

@@ -103,20 +103,20 @@ class Stree_test(unittest.TestCase):
def test_iterator_and_str(self): def test_iterator_and_str(self):
"""Check preorder iterator""" """Check preorder iterator"""
expected = [ expected = [
"root feaures=(0, 1, 2) impurity=1.0000 counts=(array([0, 1]), arr" "root feaures=(0, 1, 2) impurity=1.0000 counts=(array([0, 1]), "
"ay([750, 750]))", "array([750, 750]))",
"root - Down, <cgaf> - Leaf class=0 belief= 0.928297 impurity=0.37" "root - Down(2), <cgaf> - Leaf class=0 belief= 0.928297 impurity="
"22 counts=(array([0, 1]), array([725, 56]))", "0.3722 counts=(array([0, 1]), array([725, 56]))",
"root - Up feaures=(0, 1, 2) impurity=0.2178 counts=(array([0, 1])" "root - Up(2) feaures=(0, 1, 2) impurity=0.2178 counts=(array([0, "
", array([ 25, 694]))", "1]), array([ 25, 694]))",
"root - Up - Down feaures=(0, 1, 2) impurity=0.8454 counts=(array(" "root - Up(2) - Down(3) feaures=(0, 1, 2) impurity=0.8454 counts="
"[0, 1]), array([8, 3]))", "(array([0, 1]), array([8, 3]))",
"root - Up - Down - Down, <pure> - Leaf class=0 belief= 1.000000 i" "root - Up(2) - Down(3) - Down(4), <pure> - Leaf class=0 belief= "
"mpurity=0.0000 counts=(array([0]), array([7]))", "1.000000 impurity=0.0000 counts=(array([0]), array([7]))",
"root - Up - Down - Up, <cgaf> - Leaf class=1 belief= 0.750000 imp" "root - Up(2) - Down(3) - Up(4), <cgaf> - Leaf class=1 belief= "
"urity=0.8113 counts=(array([0, 1]), array([1, 3]))", "0.750000 impurity=0.8113 counts=(array([0, 1]), array([1, 3]))",
"root - Up - Up, <cgaf> - Leaf class=1 belief= 0.975989 impurity=0" "root - Up(2) - Up(3), <cgaf> - Leaf class=1 belief= 0.975989 "
".1634 counts=(array([0, 1]), array([ 17, 691]))", "impurity=0.1634 counts=(array([0, 1]), array([ 17, 691]))",
] ]
computed = [] computed = []
expected_string = "" expected_string = ""
@@ -198,10 +198,10 @@ class Stree_test(unittest.TestCase):
"Synt": { "Synt": {
"max_samples linear": 0.9606666666666667, "max_samples linear": 0.9606666666666667,
"max_samples rbf": 0.7133333333333334, "max_samples rbf": 0.7133333333333334,
"max_samples poly": 0.49066666666666664, "max_samples poly": 0.618,
"impurity linear": 0.9606666666666667, "impurity linear": 0.9606666666666667,
"impurity rbf": 0.7133333333333334, "impurity rbf": 0.7133333333333334,
"impurity poly": 0.49066666666666664, "impurity poly": 0.618,
}, },
"Iris": { "Iris": {
"max_samples linear": 1.0, "max_samples linear": 1.0,
@@ -378,9 +378,14 @@ class Stree_test(unittest.TestCase):
n_samples=500, n_samples=500,
) )
clf = Stree(kernel="rbf", random_state=self._random_state) clf = Stree(kernel="rbf", random_state=self._random_state)
self.assertEqual(0.824, clf.fit(X, y).score(X, y)) clf2 = Stree(
kernel="rbf", random_state=self._random_state, normalize=True
)
self.assertEqual(0.768, clf.fit(X, y).score(X, y))
self.assertEqual(0.814, clf2.fit(X, y).score(X, y))
X, y = load_wine(return_X_y=True) X, y = load_wine(return_X_y=True)
self.assertEqual(0.6741573033707865, clf.fit(X, y).score(X, y)) self.assertEqual(0.6741573033707865, clf.fit(X, y).score(X, y))
self.assertEqual(1.0, clf2.fit(X, y).score(X, y))
def test_score_multiclass_poly(self): def test_score_multiclass_poly(self):
X, y = load_dataset( X, y = load_dataset(
@@ -392,9 +397,16 @@ class Stree_test(unittest.TestCase):
clf = Stree( clf = Stree(
kernel="poly", random_state=self._random_state, C=10, degree=5 kernel="poly", random_state=self._random_state, C=10, degree=5
) )
clf2 = Stree(
kernel="poly",
random_state=self._random_state,
normalize=True,
)
self.assertEqual(0.786, clf.fit(X, y).score(X, y)) self.assertEqual(0.786, clf.fit(X, y).score(X, y))
self.assertEqual(0.818, clf2.fit(X, y).score(X, y))
X, y = load_wine(return_X_y=True) X, y = load_wine(return_X_y=True)
self.assertEqual(0.702247191011236, clf.fit(X, y).score(X, y)) self.assertEqual(0.702247191011236, clf.fit(X, y).score(X, y))
self.assertEqual(0.6067415730337079, clf2.fit(X, y).score(X, y))
def test_score_multiclass_linear(self): def test_score_multiclass_linear(self):
X, y = load_dataset( X, y = load_dataset(
@@ -405,8 +417,14 @@ class Stree_test(unittest.TestCase):
) )
clf = Stree(kernel="linear", random_state=self._random_state) clf = Stree(kernel="linear", random_state=self._random_state)
self.assertEqual(0.9533333333333334, clf.fit(X, y).score(X, y)) self.assertEqual(0.9533333333333334, clf.fit(X, y).score(X, y))
# Check with context based standardization
clf2 = Stree(
kernel="linear", random_state=self._random_state, normalize=True
)
self.assertEqual(0.9526666666666667, clf2.fit(X, y).score(X, y))
X, y = load_wine(return_X_y=True) X, y = load_wine(return_X_y=True)
self.assertEqual(0.9550561797752809, clf.fit(X, y).score(X, y)) self.assertEqual(0.9831460674157303, clf.fit(X, y).score(X, y))
self.assertEqual(1.0, clf2.fit(X, y).score(X, y))
def test_zero_all_sample_weights(self): def test_zero_all_sample_weights(self):
X, y = load_dataset(self._random_state) X, y = load_dataset(self._random_state)
@@ -439,3 +457,55 @@ class Stree_test(unittest.TestCase):
self.assertEqual(model1.score(X, y), 1) self.assertEqual(model1.score(X, y), 1)
self.assertAlmostEqual(model2.score(X, y), 0.66666667) self.assertAlmostEqual(model2.score(X, y), 0.66666667)
self.assertEqual(model2.score(X, y, w), 1) self.assertEqual(model2.score(X, y, w), 1)
def test_depth(self):
X, y = load_dataset(
random_state=self._random_state,
n_classes=3,
n_features=5,
n_samples=1500,
)
clf = Stree(random_state=self._random_state)
clf.fit(X, y)
self.assertEqual(6, clf.depth_)
X, y = load_wine(return_X_y=True)
clf = Stree(random_state=self._random_state)
clf.fit(X, y)
self.assertEqual(4, clf.depth_)
def test_nodes_leaves(self):
X, y = load_dataset(
random_state=self._random_state,
n_classes=3,
n_features=5,
n_samples=1500,
)
clf = Stree(random_state=self._random_state)
clf.fit(X, y)
nodes, leaves = clf.nodes_leaves()
self.assertEqual(25, nodes)
self.assertEquals(13, leaves)
X, y = load_wine(return_X_y=True)
clf = Stree(random_state=self._random_state)
clf.fit(X, y)
nodes, leaves = clf.nodes_leaves()
self.assertEqual(9, nodes)
self.assertEquals(5, leaves)
def test_nodes_leaves_artificial(self):
n1 = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test1")
n2 = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test2")
n3 = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test3")
n4 = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test4")
n5 = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test5")
n6 = Snode(None, [1, 2, 3, 4], [1, 0, 1, 1], [], 0.0, "test6")
n1.set_up(n2)
n2.set_up(n3)
n2.set_down(n4)
n3.set_up(n5)
n4.set_down(n6)
clf = Stree(random_state=self._random_state)
clf.tree_ = n1
nodes, leaves = clf.nodes_leaves()
self.assertEqual(6, nodes)
self.assertEqual(2, leaves)