Remove ipnb checkpoints

This commit is contained in:
2022-11-14 19:25:38 +01:00
parent a15a93a8df
commit a2561072a5

View File

@@ -1,635 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "afc3548e-91c2-4443-bd96-457a57a202cc",
"metadata": {},
"outputs": [],
"source": [
"from mdlp import MDLP\n",
"import pandas as pd\n",
"from benchmark import Datasets\n",
"from bayesclass import TAN"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "8ff3f4d6-e681-4252-ac4d-dc5bd14dcede",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>RI</th>\n",
" <th>Na</th>\n",
" <th>Mg</th>\n",
" <th>Al</th>\n",
" <th>Si</th>\n",
" <th>'K'</th>\n",
" <th>Ca</th>\n",
" <th>Ba</th>\n",
" <th>Fe</th>\n",
" <th>Type</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1.51793</td>\n",
" <td>12.79</td>\n",
" <td>3.50</td>\n",
" <td>1.12</td>\n",
" <td>73.03</td>\n",
" <td>0.64</td>\n",
" <td>8.77</td>\n",
" <td>0.0</td>\n",
" <td>0.00</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1.51643</td>\n",
" <td>12.16</td>\n",
" <td>3.52</td>\n",
" <td>1.35</td>\n",
" <td>72.89</td>\n",
" <td>0.57</td>\n",
" <td>8.53</td>\n",
" <td>0.0</td>\n",
" <td>0.00</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1.51793</td>\n",
" <td>13.21</td>\n",
" <td>3.48</td>\n",
" <td>1.41</td>\n",
" <td>72.64</td>\n",
" <td>0.59</td>\n",
" <td>8.43</td>\n",
" <td>0.0</td>\n",
" <td>0.00</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1.51299</td>\n",
" <td>14.40</td>\n",
" <td>1.74</td>\n",
" <td>1.54</td>\n",
" <td>74.55</td>\n",
" <td>0.00</td>\n",
" <td>7.59</td>\n",
" <td>0.0</td>\n",
" <td>0.00</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1.53393</td>\n",
" <td>12.30</td>\n",
" <td>0.00</td>\n",
" <td>1.00</td>\n",
" <td>70.16</td>\n",
" <td>0.12</td>\n",
" <td>16.19</td>\n",
" <td>0.0</td>\n",
" <td>0.24</td>\n",
" <td>3</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" RI Na Mg Al Si 'K' Ca Ba Fe Type\n",
"0 1.51793 12.79 3.50 1.12 73.03 0.64 8.77 0.0 0.00 0\n",
"1 1.51643 12.16 3.52 1.35 72.89 0.57 8.53 0.0 0.00 1\n",
"2 1.51793 13.21 3.48 1.41 72.64 0.59 8.43 0.0 0.00 0\n",
"3 1.51299 14.40 1.74 1.54 74.55 0.00 7.59 0.0 0.00 2\n",
"4 1.53393 12.30 0.00 1.00 70.16 0.12 16.19 0.0 0.24 3"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Get data as a dataset\n",
"dt = Datasets()\n",
"data = dt.load(\"glass\", dataframe=True)\n",
"features = dt.dataset.features\n",
"class_name = dt.dataset.class_name\n",
"factorization, class_factors = pd.factorize(data[class_name])\n",
"data[class_name] = factorization\n",
"data.head()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "7c9e1eae-6a66-4930-a125-f9f3def45574",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>RI</th>\n",
" <th>Na</th>\n",
" <th>Mg</th>\n",
" <th>Al</th>\n",
" <th>Si</th>\n",
" <th>'K'</th>\n",
" <th>Ca</th>\n",
" <th>Ba</th>\n",
" <th>Fe</th>\n",
" <th>Type</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>30.0</td>\n",
" <td>14.0</td>\n",
" <td>16.0</td>\n",
" <td>18.0</td>\n",
" <td>38.0</td>\n",
" <td>32.0</td>\n",
" <td>14.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>17.0</td>\n",
" <td>3.0</td>\n",
" <td>18.0</td>\n",
" <td>21.0</td>\n",
" <td>34.0</td>\n",
" <td>24.0</td>\n",
" <td>10.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>30.0</td>\n",
" <td>24.0</td>\n",
" <td>15.0</td>\n",
" <td>22.0</td>\n",
" <td>22.0</td>\n",
" <td>27.0</td>\n",
" <td>6.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>3.0</td>\n",
" <td>51.0</td>\n",
" <td>6.0</td>\n",
" <td>23.0</td>\n",
" <td>47.0</td>\n",
" <td>0.0</td>\n",
" <td>3.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>62.0</td>\n",
" <td>4.0</td>\n",
" <td>0.0</td>\n",
" <td>13.0</td>\n",
" <td>0.0</td>\n",
" <td>8.0</td>\n",
" <td>30.0</td>\n",
" <td>0.0</td>\n",
" <td>5.0</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>209</th>\n",
" <td>13.0</td>\n",
" <td>33.0</td>\n",
" <td>11.0</td>\n",
" <td>19.0</td>\n",
" <td>23.0</td>\n",
" <td>27.0</td>\n",
" <td>4.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>210</th>\n",
" <td>11.0</td>\n",
" <td>19.0</td>\n",
" <td>18.0</td>\n",
" <td>29.0</td>\n",
" <td>23.0</td>\n",
" <td>33.0</td>\n",
" <td>3.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>211</th>\n",
" <td>14.0</td>\n",
" <td>41.0</td>\n",
" <td>18.0</td>\n",
" <td>20.0</td>\n",
" <td>34.0</td>\n",
" <td>14.0</td>\n",
" <td>3.0</td>\n",
" <td>0.0</td>\n",
" <td>5.0</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>212</th>\n",
" <td>20.0</td>\n",
" <td>8.0</td>\n",
" <td>8.0</td>\n",
" <td>23.0</td>\n",
" <td>42.0</td>\n",
" <td>33.0</td>\n",
" <td>11.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>213</th>\n",
" <td>43.0</td>\n",
" <td>46.0</td>\n",
" <td>6.0</td>\n",
" <td>23.0</td>\n",
" <td>23.0</td>\n",
" <td>0.0</td>\n",
" <td>15.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>2</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>214 rows × 10 columns</p>\n",
"</div>"
],
"text/plain": [
" RI Na Mg Al Si 'K' Ca Ba Fe Type\n",
"0 30.0 14.0 16.0 18.0 38.0 32.0 14.0 0.0 0.0 0\n",
"1 17.0 3.0 18.0 21.0 34.0 24.0 10.0 0.0 0.0 1\n",
"2 30.0 24.0 15.0 22.0 22.0 27.0 6.0 0.0 0.0 0\n",
"3 3.0 51.0 6.0 23.0 47.0 0.0 3.0 0.0 0.0 2\n",
"4 62.0 4.0 0.0 13.0 0.0 8.0 30.0 0.0 5.0 3\n",
".. ... ... ... ... ... ... ... ... ... ...\n",
"209 13.0 33.0 11.0 19.0 23.0 27.0 4.0 0.0 0.0 1\n",
"210 11.0 19.0 18.0 29.0 23.0 33.0 3.0 0.0 0.0 3\n",
"211 14.0 41.0 18.0 20.0 34.0 14.0 3.0 0.0 5.0 3\n",
"212 20.0 8.0 8.0 23.0 42.0 33.0 11.0 0.0 0.0 3\n",
"213 43.0 46.0 6.0 23.0 23.0 0.0 15.0 0.0 0.0 2\n",
"\n",
"[214 rows x 10 columns]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Fayyad Irani\n",
"discretiz = MDLP()\n",
"Xdisc = discretiz.fit_transform(\n",
" data[features].to_numpy(), data[class_name].to_numpy()\n",
")\n",
"features_discretized = pd.DataFrame(Xdisc, columns=features)\n",
"dataset_discretized = features_discretized.copy()\n",
"dataset_discretized[class_name] = data[class_name]\n",
"X = dataset_discretized[features]\n",
"y = dataset_discretized[class_name]\n",
"dataset_discretized"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "2840a103-99fb-466f-ae75-45e11c1b9c5a",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import cross_validate, StratifiedKFold, KFold, cross_val_score\n",
"import numpy as np\n",
"n_folds = 5\n",
"score_name = \"accuracy\"\n",
"random_state=17\n",
"def validate_classifier(model, X, y, stratified, fit_params):\n",
" stratified_class = StratifiedKFold if stratified else KFold\n",
" kfold = stratified_class(shuffle=True, random_state=random_state, n_splits=n_folds)\n",
" #return cross_validate(model, X, y, cv=kfold, return_estimator=True, scoring=score_name)\n",
" return cross_val_score(model, X, y, fit_params=fit_params)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "6a1aad95-370f-4854-ae9a-32205aff5d39",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b620372c05294afc853885da0848e389",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/43 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ad33ee9f224d4abfa9a23338f07b32f2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/43 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fa05948fc73d43da8a3f01adc71c5e53",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/43 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1b67f117ceed43818c98342221b11184",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/43 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2399fb4327a242f0bd1c3e64b84382a6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/42 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n"
]
},
{
"ename": "IndexError",
"evalue": "only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn [19], line 9\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m head \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m4\u001b[39m):\n\u001b[1;32m 7\u001b[0m \u001b[38;5;66;03m#model.fit(X, y, head=head, features=features, class_name=class_name)\u001b[39;00m\n\u001b[1;32m 8\u001b[0m score \u001b[38;5;241m=\u001b[39m validate_classifier(model, X, y, stratified\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, fit_params\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mdict\u001b[39m(head\u001b[38;5;241m=\u001b[39mhead, features\u001b[38;5;241m=\u001b[39mfeatures, class_name\u001b[38;5;241m=\u001b[39mclass_name))\n\u001b[0;32m----> 9\u001b[0m model\u001b[38;5;241m.\u001b[39mplot(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msimple_init=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00msimple_init\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m head=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhead\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m score=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnp\u001b[38;5;241m.\u001b[39mmean(score[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtest_score\u001b[39m\u001b[38;5;124m'\u001b[39m])\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n",
"\u001b[0;31mIndexError\u001b[0m: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices"
]
}
],
"source": [
"import warnings\n",
"from stree import Stree\n",
"warnings.filterwarnings('ignore')\n",
"for simple_init in [False, True]:\n",
" model = TAN(simple_init=simple_init)\n",
" for head in range(4):\n",
" #model.fit(X, y, head=head, features=features, class_name=class_name)\n",
" score = validate_classifier(model, X, y, stratified=False, fit_params=dict(head=head, features=features, class_name=class_name))\n",
" #model.plot(f\"simple_init={simple_init} head={head} score={np.mean(score['test_score'])}\")\n",
" model.plot(f\"simple_init={simple_init} head={head} score={np.mean(score)}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c389ff1e-76d9-4c5b-9860-ea6d4752fac7",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "9c58629f-000b-4d8c-8896-efd032f1090c",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
},
"vscode": {
"interpreter": {
"hash": "a5f800306069c11c1b9a793f47dfeb8c7d63d06a771fda00cf3476e3d4088a52"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}