diff --git a/data/.gitignore b/data/.gitignore deleted file mode 100644 index f59ec20..0000000 --- a/data/.gitignore +++ /dev/null @@ -1 +0,0 @@ -* \ No newline at end of file diff --git a/notebooks/gridsearch.ipynb b/notebooks/gridsearch.ipynb index 182a31d..adc4978 100644 --- a/notebooks/gridsearch.ipynb +++ b/notebooks/gridsearch.ipynb @@ -8,7 +8,8 @@ "source": [ "from sklearn.ensemble import AdaBoostClassifier\n", "from sklearn.tree import DecisionTreeClassifier\n", - "from sklearn.model_selection import GridSearchCV\n", + "from sklearn.svm import LinearSVC\n", + "from sklearn.model_selection import GridSearchCV, train_test_split\n", "from sklearn.datasets import load_iris\n", "from stree import Stree" ] @@ -27,12 +28,51 @@ "cell_type": "code", "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": "Fraud: 0.244% 196\nValid: 99.755% 80234\nX.shape (1196, 28) y.shape (1196,)\nFraud: 16.472% 197\nValid: 83.528% 999\n" + } + ], "source": [ - "import pandas as pd\n", - "df = pd.read_csv('data/creditcard.csv')\n", - "y = df.Class.values\n", - "X = df.drop(['Class', 'Time', 'Amount'], axis=1).values" + "random_state=1\n", + "\n", + "def load_creditcard(n_examples=0):\n", + " import pandas as pd\n", + " import numpy as np\n", + " import random\n", + " df = pd.read_csv('data/creditcard.csv')\n", + " print(\"Fraud: {0:.3f}% {1}\".format(df.Class[df.Class == 1].count()*100/df.shape[0], df.Class[df.Class == 1].count()))\n", + " print(\"Valid: {0:.3f}% {1}\".format(df.Class[df.Class == 0].count()*100/df.shape[0], df.Class[df.Class == 0].count()))\n", + " y = df.Class\n", + " X = df.drop(['Class', 'Time', 'Amount'], axis=1).values\n", + " if n_examples > 0:\n", + " # Take first n_examples samples\n", + " X = X[:n_examples, :]\n", + " y = y[:n_examples, :]\n", + " else:\n", + " # Take all the positive samples with a number of random negatives\n", + " if n_examples < 0:\n", + " Xt = X[(y == 1).ravel()]\n", + " yt = y[(y == 1).ravel()]\n", + " indices = random.sample(range(X.shape[0]), -1 * n_examples)\n", + " X = np.append(Xt, X[indices], axis=0)\n", + " y = np.append(yt, y[indices], axis=0)\n", + " print(\"X.shape\", X.shape, \" y.shape\", y.shape)\n", + " print(\"Fraud: {0:.3f}% {1}\".format(len(y[y == 1])*100/X.shape[0], len(y[y == 1])))\n", + " print(\"Valid: {0:.3f}% {1}\".format(len(y[y == 0]) * 100 / X.shape[0], len(y[y == 0])))\n", + " Xtrain, Xtest, ytrain, ytest = train_test_split(X, y, train_size=0.7, shuffle=True, random_state=random_state, stratify=y)\n", + " return Xtrain, Xtest, ytrain, ytest\n", + "\n", + "# data = load_creditcard(-5000) # Take all true samples + 5000 of the others\n", + "# data = load_creditcard(5000) # Take the first 5000 samples\n", + "data = load_creditcard(-1000) # Take all the samples\n", + "\n", + "Xtrain = data[0]\n", + "Xtest = data[1]\n", + "ytrain = data[2]\n", + "ytest = data[3]" ] }, { @@ -43,14 +83,12 @@ { "output_type": "stream", "name": "stdout", - "text": "\n" + "text": "root\nroot - Down - Leaf class=1.0 belief=0.976000 counts=(array([0., 1.]), array([ 3, 122]))\nroot - Up - Leaf class=0.0 belief=0.977528 counts=(array([0., 1.]), array([696, 16]))\n\n" } ], "source": [ - "c = Stree(C=17, max_depth=2)\n", - "print(c)\n", - "c.fit(X, y)\n", - "print(len(list(c)))\n", + "c = Stree(max_depth=2)\n", + "c.fit(Xtrain, ytrain)\n", "print(c)" ] }, @@ -62,7 +100,7 @@ "source": [ "#'base_estimator': [DecisionTreeClassifier(max_depth=1), Stree(max_depth=2), Stree(max_depth=3)],\n", "parameters = {\n", - " 'base_estimator': [Stree(max_depth=2), Stree(max_depth=3)],\n", + " 'base_estimator': [LinearSVC(), Stree(max_depth=2), Stree(max_depth=3)],\n", " 'n_estimators': [20, 50, 100, 150],\n", " 'learning_rate': [.5, 1, 1.5] \n", "}" @@ -89,41 +127,58 @@ { "output_type": "stream", "name": "stdout", - "text": "Fitting 5 folds for each of 24 candidates, totalling 120 fits\n[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n[Parallel(n_jobs=-1)]: Done 2 tasks | elapsed: 2.5s\n[Parallel(n_jobs=-1)]: Done 9 tasks | elapsed: 2.6s\n[Parallel(n_jobs=-1)]: Done 16 tasks | elapsed: 2.6s\n[Parallel(n_jobs=-1)]: Done 25 tasks | elapsed: 2.6s\n[Parallel(n_jobs=-1)]: Batch computation too fast (0.1837s.) Setting batch_size=2.\n[Parallel(n_jobs=-1)]: Done 34 tasks | elapsed: 2.7s\n[Parallel(n_jobs=-1)]: Done 45 tasks | elapsed: 2.7s\n[Parallel(n_jobs=-1)]: Batch computation too fast (0.0313s.) Setting batch_size=4.\n[Parallel(n_jobs=-1)]: Done 64 tasks | elapsed: 2.7s\n[Parallel(n_jobs=-1)]: Batch computation too fast (0.0302s.) Setting batch_size=8.\n[Parallel(n_jobs=-1)]: Done 92 out of 120 | elapsed: 2.7s remaining: 0.8s\n[Parallel(n_jobs=-1)]: Done 118 out of 120 | elapsed: 2.7s remaining: 0.0s\n[Parallel(n_jobs=-1)]: Done 120 out of 120 | elapsed: 2.7s finished\n" - }, - { - "output_type": "error", - "ename": "ValueError", - "evalue": "Stree doesn't support sample_weight.", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mclf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mAdaBoostClassifier\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrandom_state\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrandom_state\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mgrid\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mGridSearchCV\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparameters\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_jobs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_train_score\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mgrid\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m~/.virtualenvs/general/lib/python3.7/site-packages/sklearn/utils/validation.py\u001b[0m in \u001b[0;36minner_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 71\u001b[0m FutureWarning)\n\u001b[1;32m 72\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0marg\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marg\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 73\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 74\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minner_f\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 75\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.virtualenvs/general/lib/python3.7/site-packages/sklearn/model_selection/_search.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, groups, **fit_params)\u001b[0m\n\u001b[1;32m 763\u001b[0m \u001b[0mrefit_start_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 764\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0my\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 765\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbest_estimator_\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 766\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 767\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbest_estimator_\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.virtualenvs/general/lib/python3.7/site-packages/sklearn/ensemble/_weight_boosting.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, sample_weight)\u001b[0m\n\u001b[1;32m 441\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 442\u001b[0m \u001b[0;31m# Fit\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 443\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msample_weight\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 444\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 445\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_validate_estimator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.virtualenvs/general/lib/python3.7/site-packages/sklearn/ensemble/_weight_boosting.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, sample_weight)\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0;31m# Check parameters\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 117\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_validate_estimator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 118\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[0;31m# Clear any previous fit results\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/.virtualenvs/general/lib/python3.7/site-packages/sklearn/ensemble/_weight_boosting.py\u001b[0m in \u001b[0;36m_validate_estimator\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 459\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mhas_fit_parameter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbase_estimator_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"sample_weight\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 460\u001b[0m raise ValueError(\"%s doesn't support sample_weight.\"\n\u001b[0;32m--> 461\u001b[0;31m % self.base_estimator_.__class__.__name__)\n\u001b[0m\u001b[1;32m 462\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 463\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_boost\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0miboost\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msample_weight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrandom_state\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mValueError\u001b[0m: Stree doesn't support sample_weight." - ] + "text": "(X: numpy.ndarray, y: numpy.ndarray, sample_weight: = None) -> 'Stree'\n" } ], "source": [ - "random_state=2020\n", - "clf = AdaBoostClassifier(random_state=random_state)\n", - "grid = GridSearchCV(clf, parameters, verbose=10, n_jobs=-1, return_train_score=True)\n", - "grid.fit(X, y)" + "from inspect import signature\n", + "print(signature(c.fit))" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.utils.validation import _check_sample_weight" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, "outputs": [ { "output_type": "stream", "name": "stdout", - "text": "AdaBoostClassifier(base_estimator=Stree(max_depth=2), learning_rate=0.5,\n n_estimators=20, random_state=2020)\n" + "text": "Fitting 5 folds for each of 36 candidates, totalling 180 fits\n[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n[Parallel(n_jobs=-1)]: Done 2 tasks | elapsed: 1.3s\n[Parallel(n_jobs=-1)]: Done 9 tasks | elapsed: 1.3s\n[Parallel(n_jobs=-1)]: Done 16 tasks | elapsed: 1.3s\n[Parallel(n_jobs=-1)]: Batch computation too fast (0.1671s.) Setting batch_size=2.\n[Parallel(n_jobs=-1)]: Done 25 tasks | elapsed: 1.3s\n[Parallel(n_jobs=-1)]: Done 34 tasks | elapsed: 1.4s\n[Parallel(n_jobs=-1)]: Batch computation too fast (0.0413s.) Setting batch_size=4.\n[Parallel(n_jobs=-1)]: Done 50 tasks | elapsed: 1.4s\n[Parallel(n_jobs=-1)]: Batch computation too slow (7.7880s.) Setting batch_size=1.\n[Parallel(n_jobs=-1)]: Done 74 tasks | elapsed: 9.2s\n[Parallel(n_jobs=-1)]: Done 121 tasks | elapsed: 48.9s\n[Parallel(n_jobs=-1)]: Done 140 tasks | elapsed: 1.0min\n[Parallel(n_jobs=-1)]: Done 161 tasks | elapsed: 1.3min\n[Parallel(n_jobs=-1)]: Done 180 out of 180 | elapsed: 1.6min finished\n" + }, + { + "output_type": "execute_result", + "data": { + "text/plain": "GridSearchCV(estimator=AdaBoostClassifier(random_state=2020), n_jobs=-1,\n param_grid={'base_estimator': [LinearSVC(), Stree(max_depth=2),\n Stree(max_depth=3)],\n 'learning_rate': [0.5, 1, 1.5],\n 'n_estimators': [20, 50, 100, 150]},\n return_train_score=True, verbose=10)" + }, + "metadata": {}, + "execution_count": 9 + } + ], + "source": [ + "random_state=2020\n", + "clf = AdaBoostClassifier(random_state=random_state)\n", + "grid = GridSearchCV(clf, parameters, verbose=10, n_jobs=-1, return_train_score=True)\n", + "grid.fit(Xtrain, ytrain)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": "AdaBoostClassifier(base_estimator=Stree(max_depth=2), learning_rate=0.5,\n n_estimators=150, random_state=2020)\n" } ], "source": [ @@ -131,11 +186,12 @@ ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], - "source": [] + "source": [ + "AdaBoostClassifier(base_estimator=Stree(max_depth=3), learning_rate=0.5,\n", + " n_estimators=20, random_state=2020)" + ] } ], "metadata": { diff --git a/notebooks/test2.ipynb b/notebooks/test2.ipynb index 4865b64..3d5a8c6 100644 --- a/notebooks/test2.ipynb +++ b/notebooks/test2.ipynb @@ -48,7 +48,7 @@ { "output_type": "stream", "name": "stdout", - "text": "Fraud: 0.173% 492\nValid: 99.827% 284315\nX.shape (1492, 28) y.shape (1492,)\nFraud: 33.043% 493\nValid: 66.957% 999\n" + "text": "Fraud: 0.244% 196\nValid: 99.755% 80234\nX.shape (1196, 28) y.shape (1196,)\nFraud: 16.722% 200\nValid: 83.278% 996\n" } ], "source": [ @@ -103,7 +103,7 @@ { "output_type": "stream", "name": "stdout", - "text": "depth: 1\ndepth: 2\ndepth: 2\n************** C=0.001 ****************************\nClassifier's accuracy (train): 0.9550\nClassifier's accuracy (test) : 0.9598\nroot\nroot - Down, - Leaf class=1 belief=0.983766 counts=(array([0, 1]), array([ 5, 303]))\nroot - Up, - Leaf class=0 belief=0.942935 counts=(array([0, 1]), array([694, 42]))\n\n**************************************************\ndepth: 1\ndepth: 2\ndepth: 2\ndepth: 3\n************** C=0.01 ****************************\nClassifier's accuracy (train): 0.9569\nClassifier's accuracy (test) : 0.9598\nroot\nroot - Down\nroot - Down - Down, - Leaf class=1 belief=0.990196 counts=(array([0, 1]), array([ 3, 303]))\nroot - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([2]))\nroot - Up, - Leaf class=0 belief=0.942935 counts=(array([0, 1]), array([694, 42]))\n\n**************************************************\ndepth: 1\ndepth: 2\ndepth: 3\ndepth: 4\ndepth: 3\ndepth: 2\n************** C=1 ****************************\nClassifier's accuracy (train): 0.9684\nClassifier's accuracy (test) : 0.9688\nroot\nroot - Down\nroot - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([310]))\nroot - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([4]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up\nroot - Up - Up - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Up, - Leaf class=0 belief=0.954608 counts=(array([0, 1]), array([694, 33]))\n\n**************************************************\ndepth: 1\ndepth: 2\ndepth: 2\n************** C=5 ****************************\nClassifier's accuracy (train): 0.9693\nClassifier's accuracy (test) : 0.9710\nroot\nroot - Down\nroot - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([313]))\nroot - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([7]))\nroot - Up, - Leaf class=0 belief=0.955801 counts=(array([0, 1]), array([692, 32]))\n\n**************************************************\ndepth: 1\ndepth: 2\ndepth: 3\ndepth: 4\ndepth: 5\ndepth: 6\ndepth: 6\ndepth: 5\ndepth: 4\ndepth: 3\ndepth: 2\ndepth: 3\n************** C=17 ****************************\nClassifier's accuracy (train): 0.9818\nClassifier's accuracy (test) : 0.9554\nroot\nroot - Down\nroot - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([307]))\nroot - Down - Up\nroot - Down - Up - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([8]))\nroot - Down - Up - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([25]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([3]))\nroot - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up\nroot - Up - Up - Down\nroot - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([2]))\nroot - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up - Up\nroot - Up - Up - Up - Down\nroot - Up - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([5]))\nroot - Up - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([5]))\nroot - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up - Up - Up - Up, - Leaf class=0 belief=0.972263 counts=(array([0, 1]), array([666, 19]))\n\n**************************************************\n0.6576 secs\n" + "text": "************** C=0.001 ****************************\nClassifier's accuracy (train): 0.9797\nClassifier's accuracy (test) : 0.9749\nroot\nroot - Down\nroot - Down - Down, - Leaf class=1.0 belief=0.984127 counts=(array([0., 1.]), array([ 2, 124]))\nroot - Down - Up, - Leaf class=0.0 belief=1.000000 counts=(array([0.]), array([5]))\nroot - Up\nroot - Up - Down, - Leaf class=0.0 belief=0.750000 counts=(array([0., 1.]), array([3, 1]))\nroot - Up - Up\nroot - Up - Up - Down, - Leaf class=1.0 belief=1.000000 counts=(array([1.]), array([1]))\nroot - Up - Up - Up, - Leaf class=0.0 belief=0.980029 counts=(array([0., 1.]), array([687, 14]))\n\n**************************************************\n************** C=0.01 ****************************\nClassifier's accuracy (train): 0.9809\nClassifier's accuracy (test) : 0.9749\nroot\nroot - Down, - Leaf class=1.0 belief=1.000000 counts=(array([1.]), array([124]))\nroot - Up, - Leaf class=0.0 belief=0.977560 counts=(array([0., 1.]), array([697, 16]))\n\n**************************************************\n************** C=1 ****************************\nClassifier's accuracy (train): 0.9869\nClassifier's accuracy (test) : 0.9749\nroot\nroot - Down\nroot - Down - Down, - Leaf class=1.0 belief=1.000000 counts=(array([1.]), array([129]))\nroot - Down - Up, - Leaf class=0.0 belief=1.000000 counts=(array([0.]), array([2]))\nroot - Up, - Leaf class=0.0 belief=0.984419 counts=(array([0., 1.]), array([695, 11]))\n\n**************************************************\n************** C=5 ****************************\nClassifier's accuracy (train): 0.9869\nClassifier's accuracy (test) : 0.9777\nroot\nroot - Down\nroot - Down - Down, - Leaf class=1.0 belief=1.000000 counts=(array([1.]), array([129]))\nroot - Down - Up, - Leaf class=0.0 belief=1.000000 counts=(array([0.]), array([2]))\nroot - Up, - Leaf class=0.0 belief=0.984419 counts=(array([0., 1.]), array([695, 11]))\n\n**************************************************\n************** C=17 ****************************\nClassifier's accuracy (train): 0.9916\nClassifier's accuracy (test) : 0.9833\nroot\nroot - Down\nroot - Down - Down, - Leaf class=1.0 belief=1.000000 counts=(array([1.]), array([131]))\nroot - Down - Up, - Leaf class=0.0 belief=1.000000 counts=(array([0.]), array([8]))\nroot - Up\nroot - Up - Down, - Leaf class=0.0 belief=1.000000 counts=(array([0.]), array([1]))\nroot - Up - Up\nroot - Up - Up - Down\nroot - Up - Up - Down - Down, - Leaf class=1.0 belief=1.000000 counts=(array([1.]), array([1]))\nroot - Up - Up - Down - Up, - Leaf class=0.0 belief=1.000000 counts=(array([0.]), array([5]))\nroot - Up - Up - Up\nroot - Up - Up - Up - Down, - Leaf class=1.0 belief=1.000000 counts=(array([1.]), array([1]))\nroot - Up - Up - Up - Up, - Leaf class=0.0 belief=0.989855 counts=(array([0., 1.]), array([683, 7]))\n\n**************************************************\n0.2235 secs\n" } ], "source": [ @@ -144,7 +144,7 @@ { "output_type": "stream", "name": "stdout", - "text": "root\nroot - Down\nroot - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([307]))\nroot - Down - Up\nroot - Down - Up - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([8]))\nroot - Down - Up - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([25]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([3]))\nroot - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up\nroot - Up - Up - Down\nroot - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([2]))\nroot - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up - Up\nroot - Up - Up - Up - Down\nroot - Up - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([5]))\nroot - Up - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([5]))\nroot - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up - Up - Up - Up, - Leaf class=0 belief=0.972263 counts=(array([0, 1]), array([666, 19]))\n" + "text": "root\nroot - Down\nroot - Down - Down, - Leaf class=1.0 belief=1.000000 counts=(array([1.]), array([131]))\nroot - Down - Up, - Leaf class=0.0 belief=1.000000 counts=(array([0.]), array([8]))\nroot - Up\nroot - Up - Down, - Leaf class=0.0 belief=1.000000 counts=(array([0.]), array([1]))\nroot - Up - Up\nroot - Up - Up - Down\nroot - Up - Up - Down - Down, - Leaf class=1.0 belief=1.000000 counts=(array([1.]), array([1]))\nroot - Up - Up - Down - Up, - Leaf class=0.0 belief=1.000000 counts=(array([0.]), array([5]))\nroot - Up - Up - Up\nroot - Up - Up - Up - Down, - Leaf class=1.0 belief=1.000000 counts=(array([1.]), array([1]))\nroot - Up - Up - Up - Up, - Leaf class=0.0 belief=0.989855 counts=(array([0., 1.]), array([683, 7]))\n" } ], "source": [ @@ -161,7 +161,7 @@ { "output_type": "stream", "name": "stdout", - "text": "root\nroot - Down\nroot - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([307]))\nroot - Down - Up\nroot - Down - Up - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([8]))\nroot - Down - Up - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([25]))\nroot - Up\nroot - Up - Down\nroot - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([3]))\nroot - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up\nroot - Up - Up - Down\nroot - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([2]))\nroot - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up - Up\nroot - Up - Up - Up - Down\nroot - Up - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([5]))\nroot - Up - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([5]))\nroot - Up - Up - Up - Up\nroot - Up - Up - Up - Up - Down\nroot - Up - Up - Up - Up - Down - Down, - Leaf class=1 belief=1.000000 counts=(array([1]), array([1]))\nroot - Up - Up - Up - Up - Down - Up, - Leaf class=0 belief=1.000000 counts=(array([0]), array([1]))\nroot - Up - Up - Up - Up - Up, - Leaf class=0 belief=0.972263 counts=(array([0, 1]), array([666, 19]))\n" + "text": "root\nroot - Down\nroot - Down - Down, - Leaf class=1.0 belief=1.000000 counts=(array([1.]), array([131]))\nroot - Down - Up, - Leaf class=0.0 belief=1.000000 counts=(array([0.]), array([8]))\nroot - Up\nroot - Up - Down, - Leaf class=0.0 belief=1.000000 counts=(array([0.]), array([1]))\nroot - Up - Up\nroot - Up - Up - Down\nroot - Up - Up - Down - Down, - Leaf class=1.0 belief=1.000000 counts=(array([1.]), array([1]))\nroot - Up - Up - Down - Up, - Leaf class=0.0 belief=1.000000 counts=(array([0.]), array([5]))\nroot - Up - Up - Up\nroot - Up - Up - Up - Down, - Leaf class=1.0 belief=1.000000 counts=(array([1.]), array([1]))\nroot - Up - Up - Up - Up, - Leaf class=0.0 belief=0.989855 counts=(array([0., 1.]), array([683, 7]))\n" } ], "source": [ @@ -174,13 +174,7 @@ "cell_type": "code", "execution_count": 9, "metadata": {}, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": "depth: 1\ndepth: 2\ndepth: 1\ndepth: 2\ndepth: 1\ndepth: 1\ndepth: 1\ndepth: 2\ndepth: 3\ndepth: 3\ndepth: 2\ndepth: 3\ndepth: 3\ndepth: 4\ndepth: 4\ndepth: 5\ndepth: 1\ndepth: 2\ndepth: 2\ndepth: 1\ndepth: 2\ndepth: 2\ndepth: 1\ndepth: 1\ndepth: 1\ndepth: 1\ndepth: 2\ndepth: 2\ndepth: 1\ndepth: 2\ndepth: 2\ndepth: 1\ndepth: 2\ndepth: 2\ndepth: 1\ndepth: 1\ndepth: 2\ndepth: 2\ndepth: 1\ndepth: 2\ndepth: 2\ndepth: 1\ndepth: 1\ndepth: 1\ndepth: 1\ndepth: 2\ndepth: 2\ndepth: 1\ndepth: 2\ndepth: 2\ndepth: 1\ndepth: 2\ndepth: 2\ndepth: 1\ndepth: 2\ndepth: 2\ndepth: 1\ndepth: 2\ndepth: 3\ndepth: 3\ndepth: 2\ndepth: 3\ndepth: 3\ndepth: 4\ndepth: 4\ndepth: 5\ndepth: 1\ndepth: 2\ndepth: 3\ndepth: 3\ndepth: 2\ndepth: 3\ndepth: 3\ndepth: 4\ndepth: 4\ndepth: 5\ndepth: 1\ndepth: 2\ndepth: 3\ndepth: 1\ndepth: 1\ndepth: 1\ndepth: 2\ndepth: 1\ndepth: 1\ndepth: 1\ndepth: 1\n" - } - ], + "outputs": [], "source": [ "# Check if the classifier is a sklearn estimator\n", "from sklearn.utils.estimator_checks import check_estimator\n", @@ -195,7 +189,7 @@ { "output_type": "stream", "name": "stdout", - "text": "1 functools.partial(, 'Stree')\n2 functools.partial(, 'Stree')\ndepth: 1\ndepth: 2\ndepth: 1\ndepth: 2\ndepth: 1\ndepth: 1\n3 functools.partial(, 'Stree')\ndepth: 1\ndepth: 2\ndepth: 3\ndepth: 3\ndepth: 2\ndepth: 3\ndepth: 3\ndepth: 4\ndepth: 4\ndepth: 5\n4 functools.partial(, 'Stree')\n5 functools.partial(, 'Stree')\n6 functools.partial(, 'Stree')\n7 functools.partial(, 'Stree')\n8 functools.partial(, 'Stree')\ndepth: 1\ndepth: 2\ndepth: 2\n9 functools.partial(, 'Stree', readonly_memmap=True)\ndepth: 1\ndepth: 2\ndepth: 2\n10 functools.partial(, 'Stree')\n11 functools.partial(, 'Stree')\ndepth: 1\n12 functools.partial(, 'Stree')\n13 functools.partial(, 'Stree')\ndepth: 1\ndepth: 1\n14 functools.partial(, 'Stree')\ndepth: 1\ndepth: 2\ndepth: 2\ndepth: 1\ndepth: 2\ndepth: 2\n15 functools.partial(, 'Stree')\ndepth: 1\ndepth: 2\ndepth: 2\n16 functools.partial(, 'Stree')\n17 functools.partial(, 'Stree')\ndepth: 1\n18 functools.partial(, 'Stree')\ndepth: 1\ndepth: 2\ndepth: 2\ndepth: 1\ndepth: 2\ndepth: 2\n19 functools.partial(, 'Stree')\n20 functools.partial(, 'Stree')\ndepth: 1\ndepth: 1\ndepth: 1\n21 functools.partial(, 'Stree')\n22 functools.partial(, 'Stree')\ndepth: 1\ndepth: 2\ndepth: 2\ndepth: 1\ndepth: 2\ndepth: 2\n23 functools.partial(, 'Stree', readonly_memmap=True)\ndepth: 1\ndepth: 2\ndepth: 2\ndepth: 1\ndepth: 2\ndepth: 2\n24 functools.partial(, 'Stree')\n25 functools.partial(, 'Stree')\n26 functools.partial(, 'Stree')\ndepth: 1\ndepth: 2\ndepth: 3\ndepth: 3\ndepth: 2\ndepth: 3\ndepth: 3\ndepth: 4\ndepth: 4\ndepth: 5\ndepth: 1\ndepth: 2\ndepth: 3\ndepth: 3\ndepth: 2\ndepth: 3\ndepth: 3\ndepth: 4\ndepth: 4\ndepth: 5\n27 functools.partial(, 'Stree')\n28 functools.partial(, 'Stree')\ndepth: 1\ndepth: 2\ndepth: 3\n29 functools.partial(, 'Stree')\n30 functools.partial(, 'Stree')\ndepth: 1\n31 functools.partial(, 'Stree')\ndepth: 1\n32 functools.partial(, 'Stree')\n33 functools.partial(, 'Stree')\ndepth: 1\ndepth: 2\n34 functools.partial(, 'Stree')\n35 functools.partial(, 'Stree')\n36 functools.partial(, 'Stree')\n37 functools.partial(, 'Stree')\ndepth: 1\n38 functools.partial(, 'Stree')\ndepth: 1\n39 functools.partial(, 'Stree')\ndepth: 1\ndepth: 1\n" + "text": "1 functools.partial(, 'Stree')\n2 functools.partial(, 'Stree')\n3 functools.partial(, 'Stree')\n4 functools.partial(, 'Stree')\n5 functools.partial(, 'Stree')\n6 functools.partial(, 'Stree')\n7 functools.partial(, 'Stree')\n8 functools.partial(, 'Stree')\n9 functools.partial(, 'Stree')\n10 functools.partial(, 'Stree', readonly_memmap=True)\n11 functools.partial(, 'Stree')\n12 functools.partial(, 'Stree')\n13 functools.partial(, 'Stree')\n14 functools.partial(, 'Stree')\n15 functools.partial(, 'Stree')\n16 functools.partial(, 'Stree')\n17 functools.partial(, 'Stree')\n18 functools.partial(, 'Stree')\n19 functools.partial(, 'Stree')\n20 functools.partial(, 'Stree')\n21 functools.partial(, 'Stree')\n22 functools.partial(, 'Stree')\n23 functools.partial(, 'Stree')\n24 functools.partial(, 'Stree', readonly_memmap=True)\n25 functools.partial(, 'Stree', readonly_memmap=True, X_dtype='float32')\n26 functools.partial(, 'Stree')\n27 functools.partial(, 'Stree')\n28 functools.partial(, 'Stree')\n29 functools.partial(, 'Stree')\n30 functools.partial(, 'Stree')\n31 functools.partial(, 'Stree')\n32 functools.partial(, 'Stree')\n33 functools.partial(, 'Stree')\n34 functools.partial(, 'Stree')\n35 functools.partial(, 'Stree')\n36 functools.partial(, 'Stree')\n37 functools.partial(, 'Stree')\n38 functools.partial(, 'Stree')\n39 functools.partial(, 'Stree')\n40 functools.partial(, 'Stree')\n41 functools.partial(, 'Stree')\n42 functools.partial(, 'Stree')\n" } ], "source": [ @@ -211,9 +205,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3.7.6 64-bit ('general': venv)", "language": "python", - "name": "python3" + "name": "python37664bitgeneralvenvfbd0a23e74cf4e778460f5ffc6761f39" }, "language_info": { "codemirror_mode": { diff --git a/stree/Strees.py b/stree/Strees.py index 653a928..072dd62 100644 --- a/stree/Strees.py +++ b/stree/Strees.py @@ -13,7 +13,8 @@ import os import numpy as np from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.svm import LinearSVC -from sklearn.utils.validation import check_X_y, check_array, check_is_fitted +from sklearn.utils.multiclass import check_classification_targets +from sklearn.utils.validation import check_X_y, check_array, check_is_fitted, _check_sample_weight, check_random_state class Snode: @@ -102,9 +103,8 @@ class Siterator: class Stree(BaseEstimator, ClassifierMixin): """ """ - __folder = 'data/' - def __init__(self, C: float = 1.0, max_iter: int = 1000, random_state: int = 0, + def __init__(self, C: float = 1.0, max_iter: int = 1000, random_state: int = None, max_depth: int=None, tol: float=1e-4, use_predictions: bool = False): self.max_iter = max_iter self.C = C @@ -145,25 +145,25 @@ class Stree(BaseEstimator, ClassifierMixin): return origin[up[:, 0]] if any(up) else None, \ origin[down[:, 0]] if any(down) else None - def _split_data(self, node: Snode, data: np.ndarray, indices: np.ndarray) -> list: + def _distances(self, node: Snode, data: np.ndarray) -> np.array: if self.use_predictions: - yp = node._clf.predict(data) - down = (yp == 1).reshape(-1, 1) res = np.expand_dims(node._clf.decision_function(data), 1) else: # doesn't work with multiclass as each sample has to do inner product with its own coeficients # computes positition of every sample is w.r.t. the hyperplane res = self._linear_function(data, node) - down = res > 0 - data_up, data_down = self._split_array(data, down) - indices_up, indices_down = self._split_array(indices, down) - res_up, res_down = self._split_array(res, down) - return [data_up, indices_up, data_down, indices_down, res_up, res_down] + # data_up, data_down = self._split_array(data, down) + # indices_up, indices_down = self._split_array(indices, down) + # res_up, res_down = self._split_array(res, down) + # weight_up, weight_down = self._split_array(weights, down) + #return [data_up, indices_up, data_down, indices_down, weight_up, weight_down, res_up, res_down] + return res - def fit(self, X: np.ndarray, y: np.ndarray, weighted_samples: np.array=None, **fitparams: dict) -> 'Stree': - from sklearn.utils.multiclass import check_classification_targets - if fitparams is not None: - self.set_params(**fitparams) + def _split_criteria(self, data: np.array) -> np.array: + return data > 0 + + def fit(self, X: np.ndarray, y: np.ndarray, sample_weight: np.array = None) -> 'Stree': + # Check parameters are Ok. if type(y).__name__ == 'np.ndarray': y = y.ravel() if self.C < 0: @@ -173,12 +173,15 @@ class Stree(BaseEstimator, ClassifierMixin): raise ValueError(f"Maximum depth has to be greater than 1... got (max_depth={self.max_depth})") check_classification_targets(y) X, y = check_X_y(X, y) + sample_weight = _check_sample_weight(sample_weight, X) + check_classification_targets(y) + # Initialize computed parameters + #self.random_state = check_random_state(self.random_state) self.classes_ = np.unique(y) self.n_iter_ = self.max_iter self.depth_ = 0 - check_classification_targets(y) self.n_features_in_ = X.shape[1] - self.tree_ = self.train(X, y, 1, 'root') + self.tree_ = self.train(X, y, sample_weight, 1, 'root') self._build_predictor() return self @@ -195,7 +198,7 @@ class Stree(BaseEstimator, ClassifierMixin): run_tree(self.tree_) - def train(self, X: np.ndarray, y: np.ndarray, depth: int, title: str = 'root') -> Snode: + def train(self, X: np.ndarray, y: np.ndarray, sample_weight: np.ndarray, depth: int, title: str) -> Snode: if depth > self.__max_depth: return None @@ -203,21 +206,24 @@ class Stree(BaseEstimator, ClassifierMixin): # only 1 class => pure dataset return Snode(None, X, y, title + ', ') # Train the model - clf = LinearSVC(max_iter=self.max_iter, C=self.C, - random_state=self.random_state) - clf.fit(X, y) + clf = LinearSVC(max_iter=self.max_iter, random_state=self.random_state, + C=self.C) #, sample_weight=sample_weight) + clf.fit(X, y, sample_weight=sample_weight) tree = Snode(clf, X, y, title) self.depth_ = max(depth, self.depth_) - X_U, y_u, X_D, y_d, _, _ = self._split_data(tree, X, y) + down = self._split_criteria(self._distances(tree, X)) + X_U, X_D = self._split_array(X, down) + y_u, y_d = self._split_array(y, down) + sw_u, sw_d = self._split_array(sample_weight, down) if X_U is None or X_D is None: # didn't part anything return Snode(clf, X, y, title + ', ') - tree.set_up(self.train(X_U, y_u, depth + 1, title + ' - Up')) - tree.set_down(self.train(X_D, y_d, depth + 1, title + ' - Down')) + tree.set_up(self.train(X_U, y_u, sw_u, depth + 1, title + ' - Up')) + tree.set_down(self.train(X_D, y_d, sw_d, depth + 1, title + ' - Down')) return tree - def _reorder_results(self, y: np.array, indices: np.array, proba=False) -> np.array: - if proba: + def _reorder_results(self, y: np.array, indices: np.array) -> np.array: + if y.ndim > 1 and y.shape[1] > 1: # if predict_proba return np.array of floats y_ordered = np.zeros(y.shape, dtype=float) else: @@ -236,10 +242,12 @@ class Stree(BaseEstimator, ClassifierMixin): # set a class for every sample in dataset prediction = np.full((xp.shape[0], 1), node._class) return prediction, indices - u, i_u, d, i_d, _, _ = self._split_data(node, xp, indices) - k, l = predict_class(d, i_d, node.get_down()) - m, n = predict_class(u, i_u, node.get_up()) - return np.append(k, m), np.append(l, n) + down = self._split_criteria(self._distances(node, xp)) + X_U, X_D = self._split_array(xp, down) + i_u, i_d = self._split_array(indices, down) + prx_u, prin_u = predict_class(X_U, i_u, node.get_up()) + prx_d, prin_d = predict_class(X_D, i_d, node.get_down()) + return np.append(prx_u, prx_d), np.append(prin_u, prin_d) # sklearn check check_is_fitted(self, ['tree_']) @@ -276,10 +284,15 @@ class Stree(BaseEstimator, ClassifierMixin): prediction = np.full((xp.shape[0], 1), node._class) prediction_proba = dist return np.append(prediction, prediction_proba, axis=1), indices - u, i_u, d, i_d, r_u, r_d = self._split_data(node, xp, indices) - k, l = predict_class(d, i_d, r_d, node.get_down()) - m, n = predict_class(u, i_u, r_u, node.get_up()) - return np.append(k, m), np.append(l, n) + distances = self._distances(node, xp) + down = self._split_criteria(distances) + + X_U, X_D = self._split_array(xp, down) + i_u, i_d = self._split_array(indices, down) + di_u, di_d = self._split_array(distances, down) + prx_u, prin_u = predict_class(X_U, i_u, di_u, node.get_up()) + prx_d, prin_d = predict_class(X_D, i_d, di_d, node.get_down()) + return np.append(prx_u, prx_d), np.append(prin_u, prin_d) # sklearn check check_is_fitted(self, ['tree_']) @@ -295,7 +308,7 @@ class Stree(BaseEstimator, ClassifierMixin): # Probability of being 1 result[:, 1] = 1 / (1 + np.exp(-result[:, 1])) result[:, 0] = 1 - result[:, 1] # Probability of being 0 - return self._reorder_results(result, indices, proba=True) + return self._reorder_results(result, indices) def score(self, X: np.array, y: np.array) -> float: """Return accuracy @@ -319,35 +332,3 @@ class Stree(BaseEstimator, ClassifierMixin): output += str(i) + '\n' return output - def get_folder(self) -> str: - return self.__folder - - def _save_datasets(self, tree: Snode, catalog: typing.TextIO, number: int): - """Save the dataset of the node in a csv file - - :param tree: node with data to save - :type tree: Snode - :param catalog: catalog file handler - :type catalog: typing.TextIO - :param number: sequential number for the generated file name - :type number: int - """ - data = np.append(tree._X, tree._y.reshape(-1, 1), axis=1) - name = f"{self.__folder}dataset{number}.csv" - np.savetxt(name, data, delimiter=",") - catalog.write(f"{name}, - {str(tree)}") - if tree.is_leaf(): - return - self._save_datasets(tree.get_down(), catalog, number + 1) - self._save_datasets(tree.get_up(), catalog, number + 2) - - def get_catalog_name(self): - return self.__folder + "catalog.txt" - - def save_sub_datasets(self): - """Save the every dataset stored in the tree to check with manual classifier - """ - if not os.path.isdir(self.__folder): - os.mkdir(self.__folder) - with open(self.get_catalog_name(), 'w', encoding='utf-8') as catalog: - self._save_datasets(self.tree_, catalog, 1) diff --git a/stree/tests/Strees_test.py b/stree/tests/Strees_test.py index 720b452..7e32cfe 100644 --- a/stree/tests/Strees_test.py +++ b/stree/tests/Strees_test.py @@ -107,24 +107,6 @@ class Stree_test(unittest.TestCase): res.append(y_original[row]) return res - def test_subdatasets(self): - """Check if the subdatasets files have the same labels as the original dataset - """ - self._clf.save_sub_datasets() - with open(self._clf.get_catalog_name()) as cat_file: - catalog = csv.reader(cat_file, delimiter=',') - for row in catalog: - X, y = self._get_Xy() - x_file, y_file = self._get_file_data(row[0]) - y_original = np.array(self._find_out(x_file, X, y), dtype=int) - self.assertTrue(np.array_equal(y_file, y_original)) - if os.path.isdir(self._clf.get_folder()): - try: - os.remove(f"{self._clf.get_folder()}*") - os.rmdir(self._clf.get_folder()) - except: - pass - def test_single_prediction(self): X, y = self._get_Xy() yp = self._clf.predict((X[0, :].reshape(-1, X.shape[1]))) @@ -141,10 +123,9 @@ class Stree_test(unittest.TestCase): X, y = self._get_Xy() accuracy_score = self._clf.score(X, y) yp = self._clf.predict(X) - right = (yp == y).astype(int) - accuracy_computed = sum(right) / len(y) + accuracy_computed = np.mean(yp == y) self.assertEqual(accuracy_score, accuracy_computed) - self.assertGreater(accuracy_score, 0.8) + self.assertGreater(accuracy_score, 0.9) def test_single_predict_proba(self): """Check that element 28 has a prediction different that the current label