mirror of
https://github.com/Doctorado-ML/bayesclass.git
synced 2025-08-18 17:15:53 +00:00
Remove ipnb checkpoints
This commit is contained in:
@@ -1,635 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 1,
|
|
||||||
"id": "afc3548e-91c2-4443-bd96-457a57a202cc",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"from mdlp import MDLP\n",
|
|
||||||
"import pandas as pd\n",
|
|
||||||
"from benchmark import Datasets\n",
|
|
||||||
"from bayesclass import TAN"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 2,
|
|
||||||
"id": "8ff3f4d6-e681-4252-ac4d-dc5bd14dcede",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/html": [
|
|
||||||
"<div>\n",
|
|
||||||
"<style scoped>\n",
|
|
||||||
" .dataframe tbody tr th:only-of-type {\n",
|
|
||||||
" vertical-align: middle;\n",
|
|
||||||
" }\n",
|
|
||||||
"\n",
|
|
||||||
" .dataframe tbody tr th {\n",
|
|
||||||
" vertical-align: top;\n",
|
|
||||||
" }\n",
|
|
||||||
"\n",
|
|
||||||
" .dataframe thead th {\n",
|
|
||||||
" text-align: right;\n",
|
|
||||||
" }\n",
|
|
||||||
"</style>\n",
|
|
||||||
"<table border=\"1\" class=\"dataframe\">\n",
|
|
||||||
" <thead>\n",
|
|
||||||
" <tr style=\"text-align: right;\">\n",
|
|
||||||
" <th></th>\n",
|
|
||||||
" <th>RI</th>\n",
|
|
||||||
" <th>Na</th>\n",
|
|
||||||
" <th>Mg</th>\n",
|
|
||||||
" <th>Al</th>\n",
|
|
||||||
" <th>Si</th>\n",
|
|
||||||
" <th>'K'</th>\n",
|
|
||||||
" <th>Ca</th>\n",
|
|
||||||
" <th>Ba</th>\n",
|
|
||||||
" <th>Fe</th>\n",
|
|
||||||
" <th>Type</th>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" </thead>\n",
|
|
||||||
" <tbody>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>0</th>\n",
|
|
||||||
" <td>1.51793</td>\n",
|
|
||||||
" <td>12.79</td>\n",
|
|
||||||
" <td>3.50</td>\n",
|
|
||||||
" <td>1.12</td>\n",
|
|
||||||
" <td>73.03</td>\n",
|
|
||||||
" <td>0.64</td>\n",
|
|
||||||
" <td>8.77</td>\n",
|
|
||||||
" <td>0.0</td>\n",
|
|
||||||
" <td>0.00</td>\n",
|
|
||||||
" <td>0</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>1</th>\n",
|
|
||||||
" <td>1.51643</td>\n",
|
|
||||||
" <td>12.16</td>\n",
|
|
||||||
" <td>3.52</td>\n",
|
|
||||||
" <td>1.35</td>\n",
|
|
||||||
" <td>72.89</td>\n",
|
|
||||||
" <td>0.57</td>\n",
|
|
||||||
" <td>8.53</td>\n",
|
|
||||||
" <td>0.0</td>\n",
|
|
||||||
" <td>0.00</td>\n",
|
|
||||||
" <td>1</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>2</th>\n",
|
|
||||||
" <td>1.51793</td>\n",
|
|
||||||
" <td>13.21</td>\n",
|
|
||||||
" <td>3.48</td>\n",
|
|
||||||
" <td>1.41</td>\n",
|
|
||||||
" <td>72.64</td>\n",
|
|
||||||
" <td>0.59</td>\n",
|
|
||||||
" <td>8.43</td>\n",
|
|
||||||
" <td>0.0</td>\n",
|
|
||||||
" <td>0.00</td>\n",
|
|
||||||
" <td>0</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>3</th>\n",
|
|
||||||
" <td>1.51299</td>\n",
|
|
||||||
" <td>14.40</td>\n",
|
|
||||||
" <td>1.74</td>\n",
|
|
||||||
" <td>1.54</td>\n",
|
|
||||||
" <td>74.55</td>\n",
|
|
||||||
" <td>0.00</td>\n",
|
|
||||||
" <td>7.59</td>\n",
|
|
||||||
" <td>0.0</td>\n",
|
|
||||||
" <td>0.00</td>\n",
|
|
||||||
" <td>2</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>4</th>\n",
|
|
||||||
" <td>1.53393</td>\n",
|
|
||||||
" <td>12.30</td>\n",
|
|
||||||
" <td>0.00</td>\n",
|
|
||||||
" <td>1.00</td>\n",
|
|
||||||
" <td>70.16</td>\n",
|
|
||||||
" <td>0.12</td>\n",
|
|
||||||
" <td>16.19</td>\n",
|
|
||||||
" <td>0.0</td>\n",
|
|
||||||
" <td>0.24</td>\n",
|
|
||||||
" <td>3</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" </tbody>\n",
|
|
||||||
"</table>\n",
|
|
||||||
"</div>"
|
|
||||||
],
|
|
||||||
"text/plain": [
|
|
||||||
" RI Na Mg Al Si 'K' Ca Ba Fe Type\n",
|
|
||||||
"0 1.51793 12.79 3.50 1.12 73.03 0.64 8.77 0.0 0.00 0\n",
|
|
||||||
"1 1.51643 12.16 3.52 1.35 72.89 0.57 8.53 0.0 0.00 1\n",
|
|
||||||
"2 1.51793 13.21 3.48 1.41 72.64 0.59 8.43 0.0 0.00 0\n",
|
|
||||||
"3 1.51299 14.40 1.74 1.54 74.55 0.00 7.59 0.0 0.00 2\n",
|
|
||||||
"4 1.53393 12.30 0.00 1.00 70.16 0.12 16.19 0.0 0.24 3"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 2,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"# Get data as a dataset\n",
|
|
||||||
"dt = Datasets()\n",
|
|
||||||
"data = dt.load(\"glass\", dataframe=True)\n",
|
|
||||||
"features = dt.dataset.features\n",
|
|
||||||
"class_name = dt.dataset.class_name\n",
|
|
||||||
"factorization, class_factors = pd.factorize(data[class_name])\n",
|
|
||||||
"data[class_name] = factorization\n",
|
|
||||||
"data.head()"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 3,
|
|
||||||
"id": "7c9e1eae-6a66-4930-a125-f9f3def45574",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/html": [
|
|
||||||
"<div>\n",
|
|
||||||
"<style scoped>\n",
|
|
||||||
" .dataframe tbody tr th:only-of-type {\n",
|
|
||||||
" vertical-align: middle;\n",
|
|
||||||
" }\n",
|
|
||||||
"\n",
|
|
||||||
" .dataframe tbody tr th {\n",
|
|
||||||
" vertical-align: top;\n",
|
|
||||||
" }\n",
|
|
||||||
"\n",
|
|
||||||
" .dataframe thead th {\n",
|
|
||||||
" text-align: right;\n",
|
|
||||||
" }\n",
|
|
||||||
"</style>\n",
|
|
||||||
"<table border=\"1\" class=\"dataframe\">\n",
|
|
||||||
" <thead>\n",
|
|
||||||
" <tr style=\"text-align: right;\">\n",
|
|
||||||
" <th></th>\n",
|
|
||||||
" <th>RI</th>\n",
|
|
||||||
" <th>Na</th>\n",
|
|
||||||
" <th>Mg</th>\n",
|
|
||||||
" <th>Al</th>\n",
|
|
||||||
" <th>Si</th>\n",
|
|
||||||
" <th>'K'</th>\n",
|
|
||||||
" <th>Ca</th>\n",
|
|
||||||
" <th>Ba</th>\n",
|
|
||||||
" <th>Fe</th>\n",
|
|
||||||
" <th>Type</th>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" </thead>\n",
|
|
||||||
" <tbody>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>0</th>\n",
|
|
||||||
" <td>30.0</td>\n",
|
|
||||||
" <td>14.0</td>\n",
|
|
||||||
" <td>16.0</td>\n",
|
|
||||||
" <td>18.0</td>\n",
|
|
||||||
" <td>38.0</td>\n",
|
|
||||||
" <td>32.0</td>\n",
|
|
||||||
" <td>14.0</td>\n",
|
|
||||||
" <td>0.0</td>\n",
|
|
||||||
" <td>0.0</td>\n",
|
|
||||||
" <td>0</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>1</th>\n",
|
|
||||||
" <td>17.0</td>\n",
|
|
||||||
" <td>3.0</td>\n",
|
|
||||||
" <td>18.0</td>\n",
|
|
||||||
" <td>21.0</td>\n",
|
|
||||||
" <td>34.0</td>\n",
|
|
||||||
" <td>24.0</td>\n",
|
|
||||||
" <td>10.0</td>\n",
|
|
||||||
" <td>0.0</td>\n",
|
|
||||||
" <td>0.0</td>\n",
|
|
||||||
" <td>1</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>2</th>\n",
|
|
||||||
" <td>30.0</td>\n",
|
|
||||||
" <td>24.0</td>\n",
|
|
||||||
" <td>15.0</td>\n",
|
|
||||||
" <td>22.0</td>\n",
|
|
||||||
" <td>22.0</td>\n",
|
|
||||||
" <td>27.0</td>\n",
|
|
||||||
" <td>6.0</td>\n",
|
|
||||||
" <td>0.0</td>\n",
|
|
||||||
" <td>0.0</td>\n",
|
|
||||||
" <td>0</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>3</th>\n",
|
|
||||||
" <td>3.0</td>\n",
|
|
||||||
" <td>51.0</td>\n",
|
|
||||||
" <td>6.0</td>\n",
|
|
||||||
" <td>23.0</td>\n",
|
|
||||||
" <td>47.0</td>\n",
|
|
||||||
" <td>0.0</td>\n",
|
|
||||||
" <td>3.0</td>\n",
|
|
||||||
" <td>0.0</td>\n",
|
|
||||||
" <td>0.0</td>\n",
|
|
||||||
" <td>2</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>4</th>\n",
|
|
||||||
" <td>62.0</td>\n",
|
|
||||||
" <td>4.0</td>\n",
|
|
||||||
" <td>0.0</td>\n",
|
|
||||||
" <td>13.0</td>\n",
|
|
||||||
" <td>0.0</td>\n",
|
|
||||||
" <td>8.0</td>\n",
|
|
||||||
" <td>30.0</td>\n",
|
|
||||||
" <td>0.0</td>\n",
|
|
||||||
" <td>5.0</td>\n",
|
|
||||||
" <td>3</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>...</th>\n",
|
|
||||||
" <td>...</td>\n",
|
|
||||||
" <td>...</td>\n",
|
|
||||||
" <td>...</td>\n",
|
|
||||||
" <td>...</td>\n",
|
|
||||||
" <td>...</td>\n",
|
|
||||||
" <td>...</td>\n",
|
|
||||||
" <td>...</td>\n",
|
|
||||||
" <td>...</td>\n",
|
|
||||||
" <td>...</td>\n",
|
|
||||||
" <td>...</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>209</th>\n",
|
|
||||||
" <td>13.0</td>\n",
|
|
||||||
" <td>33.0</td>\n",
|
|
||||||
" <td>11.0</td>\n",
|
|
||||||
" <td>19.0</td>\n",
|
|
||||||
" <td>23.0</td>\n",
|
|
||||||
" <td>27.0</td>\n",
|
|
||||||
" <td>4.0</td>\n",
|
|
||||||
" <td>0.0</td>\n",
|
|
||||||
" <td>0.0</td>\n",
|
|
||||||
" <td>1</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>210</th>\n",
|
|
||||||
" <td>11.0</td>\n",
|
|
||||||
" <td>19.0</td>\n",
|
|
||||||
" <td>18.0</td>\n",
|
|
||||||
" <td>29.0</td>\n",
|
|
||||||
" <td>23.0</td>\n",
|
|
||||||
" <td>33.0</td>\n",
|
|
||||||
" <td>3.0</td>\n",
|
|
||||||
" <td>0.0</td>\n",
|
|
||||||
" <td>0.0</td>\n",
|
|
||||||
" <td>3</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>211</th>\n",
|
|
||||||
" <td>14.0</td>\n",
|
|
||||||
" <td>41.0</td>\n",
|
|
||||||
" <td>18.0</td>\n",
|
|
||||||
" <td>20.0</td>\n",
|
|
||||||
" <td>34.0</td>\n",
|
|
||||||
" <td>14.0</td>\n",
|
|
||||||
" <td>3.0</td>\n",
|
|
||||||
" <td>0.0</td>\n",
|
|
||||||
" <td>5.0</td>\n",
|
|
||||||
" <td>3</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>212</th>\n",
|
|
||||||
" <td>20.0</td>\n",
|
|
||||||
" <td>8.0</td>\n",
|
|
||||||
" <td>8.0</td>\n",
|
|
||||||
" <td>23.0</td>\n",
|
|
||||||
" <td>42.0</td>\n",
|
|
||||||
" <td>33.0</td>\n",
|
|
||||||
" <td>11.0</td>\n",
|
|
||||||
" <td>0.0</td>\n",
|
|
||||||
" <td>0.0</td>\n",
|
|
||||||
" <td>3</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" <tr>\n",
|
|
||||||
" <th>213</th>\n",
|
|
||||||
" <td>43.0</td>\n",
|
|
||||||
" <td>46.0</td>\n",
|
|
||||||
" <td>6.0</td>\n",
|
|
||||||
" <td>23.0</td>\n",
|
|
||||||
" <td>23.0</td>\n",
|
|
||||||
" <td>0.0</td>\n",
|
|
||||||
" <td>15.0</td>\n",
|
|
||||||
" <td>0.0</td>\n",
|
|
||||||
" <td>0.0</td>\n",
|
|
||||||
" <td>2</td>\n",
|
|
||||||
" </tr>\n",
|
|
||||||
" </tbody>\n",
|
|
||||||
"</table>\n",
|
|
||||||
"<p>214 rows × 10 columns</p>\n",
|
|
||||||
"</div>"
|
|
||||||
],
|
|
||||||
"text/plain": [
|
|
||||||
" RI Na Mg Al Si 'K' Ca Ba Fe Type\n",
|
|
||||||
"0 30.0 14.0 16.0 18.0 38.0 32.0 14.0 0.0 0.0 0\n",
|
|
||||||
"1 17.0 3.0 18.0 21.0 34.0 24.0 10.0 0.0 0.0 1\n",
|
|
||||||
"2 30.0 24.0 15.0 22.0 22.0 27.0 6.0 0.0 0.0 0\n",
|
|
||||||
"3 3.0 51.0 6.0 23.0 47.0 0.0 3.0 0.0 0.0 2\n",
|
|
||||||
"4 62.0 4.0 0.0 13.0 0.0 8.0 30.0 0.0 5.0 3\n",
|
|
||||||
".. ... ... ... ... ... ... ... ... ... ...\n",
|
|
||||||
"209 13.0 33.0 11.0 19.0 23.0 27.0 4.0 0.0 0.0 1\n",
|
|
||||||
"210 11.0 19.0 18.0 29.0 23.0 33.0 3.0 0.0 0.0 3\n",
|
|
||||||
"211 14.0 41.0 18.0 20.0 34.0 14.0 3.0 0.0 5.0 3\n",
|
|
||||||
"212 20.0 8.0 8.0 23.0 42.0 33.0 11.0 0.0 0.0 3\n",
|
|
||||||
"213 43.0 46.0 6.0 23.0 23.0 0.0 15.0 0.0 0.0 2\n",
|
|
||||||
"\n",
|
|
||||||
"[214 rows x 10 columns]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 3,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"# Fayyad Irani\n",
|
|
||||||
"discretiz = MDLP()\n",
|
|
||||||
"Xdisc = discretiz.fit_transform(\n",
|
|
||||||
" data[features].to_numpy(), data[class_name].to_numpy()\n",
|
|
||||||
")\n",
|
|
||||||
"features_discretized = pd.DataFrame(Xdisc, columns=features)\n",
|
|
||||||
"dataset_discretized = features_discretized.copy()\n",
|
|
||||||
"dataset_discretized[class_name] = data[class_name]\n",
|
|
||||||
"X = dataset_discretized[features]\n",
|
|
||||||
"y = dataset_discretized[class_name]\n",
|
|
||||||
"dataset_discretized"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 17,
|
|
||||||
"id": "2840a103-99fb-466f-ae75-45e11c1b9c5a",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"from sklearn.model_selection import cross_validate, StratifiedKFold, KFold, cross_val_score\n",
|
|
||||||
"import numpy as np\n",
|
|
||||||
"n_folds = 5\n",
|
|
||||||
"score_name = \"accuracy\"\n",
|
|
||||||
"random_state=17\n",
|
|
||||||
"def validate_classifier(model, X, y, stratified, fit_params):\n",
|
|
||||||
" stratified_class = StratifiedKFold if stratified else KFold\n",
|
|
||||||
" kfold = stratified_class(shuffle=True, random_state=random_state, n_splits=n_folds)\n",
|
|
||||||
" #return cross_validate(model, X, y, cv=kfold, return_estimator=True, scoring=score_name)\n",
|
|
||||||
" return cross_val_score(model, X, y, fit_params=fit_params)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 19,
|
|
||||||
"id": "6a1aad95-370f-4854-ae9a-32205aff5d39",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "b620372c05294afc853885da0848e389",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
" 0%| | 0/43 [00:00<?, ?it/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n",
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n",
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n",
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n",
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n",
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n",
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "ad33ee9f224d4abfa9a23338f07b32f2",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
" 0%| | 0/43 [00:00<?, ?it/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n",
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n",
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n",
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n",
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n",
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n",
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n",
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "fa05948fc73d43da8a3f01adc71c5e53",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
" 0%| | 0/43 [00:00<?, ?it/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n",
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n",
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n",
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n",
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n",
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n",
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "1b67f117ceed43818c98342221b11184",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
" 0%| | 0/43 [00:00<?, ?it/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n",
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n",
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n",
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n",
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n",
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
|
||||||
"model_id": "2399fb4327a242f0bd1c3e64b84382a6",
|
|
||||||
"version_major": 2,
|
|
||||||
"version_minor": 0
|
|
||||||
},
|
|
||||||
"text/plain": [
|
|
||||||
" 0%| | 0/42 [00:00<?, ?it/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "display_data"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n",
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n",
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n",
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n",
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n",
|
|
||||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
|
||||||
" warnings.warn(\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"ename": "IndexError",
|
|
||||||
"evalue": "only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices",
|
|
||||||
"output_type": "error",
|
|
||||||
"traceback": [
|
|
||||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
||||||
"\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)",
|
|
||||||
"Cell \u001b[0;32mIn [19], line 9\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m head \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m4\u001b[39m):\n\u001b[1;32m 7\u001b[0m \u001b[38;5;66;03m#model.fit(X, y, head=head, features=features, class_name=class_name)\u001b[39;00m\n\u001b[1;32m 8\u001b[0m score \u001b[38;5;241m=\u001b[39m validate_classifier(model, X, y, stratified\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, fit_params\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mdict\u001b[39m(head\u001b[38;5;241m=\u001b[39mhead, features\u001b[38;5;241m=\u001b[39mfeatures, class_name\u001b[38;5;241m=\u001b[39mclass_name))\n\u001b[0;32m----> 9\u001b[0m model\u001b[38;5;241m.\u001b[39mplot(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msimple_init=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00msimple_init\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m head=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhead\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m score=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnp\u001b[38;5;241m.\u001b[39mmean(score[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtest_score\u001b[39m\u001b[38;5;124m'\u001b[39m])\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n",
|
|
||||||
"\u001b[0;31mIndexError\u001b[0m: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"import warnings\n",
|
|
||||||
"from stree import Stree\n",
|
|
||||||
"warnings.filterwarnings('ignore')\n",
|
|
||||||
"for simple_init in [False, True]:\n",
|
|
||||||
" model = TAN(simple_init=simple_init)\n",
|
|
||||||
" for head in range(4):\n",
|
|
||||||
" #model.fit(X, y, head=head, features=features, class_name=class_name)\n",
|
|
||||||
" score = validate_classifier(model, X, y, stratified=False, fit_params=dict(head=head, features=features, class_name=class_name))\n",
|
|
||||||
" #model.plot(f\"simple_init={simple_init} head={head} score={np.mean(score['test_score'])}\")\n",
|
|
||||||
" model.plot(f\"simple_init={simple_init} head={head} score={np.mean(score)}\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "c389ff1e-76d9-4c5b-9860-ea6d4752fac7",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": []
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"id": "9c58629f-000b-4d8c-8896-efd032f1090c",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": []
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "Python 3 (ipykernel)",
|
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.10.8"
|
|
||||||
},
|
|
||||||
"vscode": {
|
|
||||||
"interpreter": {
|
|
||||||
"hash": "a5f800306069c11c1b9a793f47dfeb8c7d63d06a771fda00cf3476e3d4088a52"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 5
|
|
||||||
}
|
|
Reference in New Issue
Block a user