diff --git a/.ipynb_checkpoints/test-checkpoint.ipynb b/.ipynb_checkpoints/test-checkpoint.ipynb index 48910d4..e92fe3c 100644 --- a/.ipynb_checkpoints/test-checkpoint.ipynb +++ b/.ipynb_checkpoints/test-checkpoint.ipynb @@ -190,25 +190,25 @@ " \n", " 0\n", " 30.0\n", - " 8.0\n", - " 11.0\n", - " 12.0\n", - " 44.0\n", - " 30.0\n", - " 4.0\n", + " 14.0\n", + " 16.0\n", + " 18.0\n", + " 38.0\n", + " 32.0\n", + " 14.0\n", " 0.0\n", " 0.0\n", " 0\n", " \n", " \n", " 1\n", - " 22.0\n", + " 17.0\n", " 3.0\n", - " 11.0\n", - " 13.0\n", - " 43.0\n", - " 27.0\n", - " 4.0\n", + " 18.0\n", + " 21.0\n", + " 34.0\n", + " 24.0\n", + " 10.0\n", " 0.0\n", " 0.0\n", " 1\n", @@ -216,12 +216,12 @@ " \n", " 2\n", " 30.0\n", + " 24.0\n", " 15.0\n", - " 11.0\n", - " 14.0\n", - " 29.0\n", - " 28.0\n", - " 4.0\n", + " 22.0\n", + " 22.0\n", + " 27.0\n", + " 6.0\n", " 0.0\n", " 0.0\n", " 0\n", @@ -229,10 +229,10 @@ " \n", " 3\n", " 3.0\n", - " 32.0\n", - " 1.0\n", - " 16.0\n", - " 52.0\n", + " 51.0\n", + " 6.0\n", + " 23.0\n", + " 47.0\n", " 0.0\n", " 3.0\n", " 0.0\n", @@ -241,13 +241,13 @@ " \n", " \n", " 4\n", - " 63.0\n", + " 62.0\n", " 4.0\n", " 0.0\n", - " 9.0\n", + " 13.0\n", " 0.0\n", - " 10.0\n", - " 11.0\n", + " 8.0\n", + " 30.0\n", " 0.0\n", " 5.0\n", " 3\n", @@ -267,12 +267,12 @@ " \n", " \n", " 209\n", - " 16.0\n", - " 19.0\n", - " 10.0\n", " 13.0\n", - " 32.0\n", - " 28.0\n", + " 33.0\n", + " 11.0\n", + " 19.0\n", + " 23.0\n", + " 27.0\n", " 4.0\n", " 0.0\n", " 0.0\n", @@ -280,12 +280,12 @@ " \n", " \n", " 210\n", - " 14.0\n", " 11.0\n", - " 11.0\n", - " 26.0\n", - " 30.0\n", - " 30.0\n", + " 19.0\n", + " 18.0\n", + " 29.0\n", + " 23.0\n", + " 33.0\n", " 3.0\n", " 0.0\n", " 0.0\n", @@ -293,12 +293,12 @@ " \n", " \n", " 211\n", + " 14.0\n", + " 41.0\n", " 18.0\n", - " 23.0\n", - " 11.0\n", - " 13.0\n", - " 42.0\n", - " 15.0\n", + " 20.0\n", + " 34.0\n", + " 14.0\n", " 3.0\n", " 0.0\n", " 5.0\n", @@ -306,13 +306,13 @@ " \n", " \n", " 212\n", - " 22.0\n", - " 5.0\n", - " 3.0\n", - " 16.0\n", - " 47.0\n", - " 30.0\n", - " 4.0\n", + " 20.0\n", + " 8.0\n", + " 8.0\n", + " 23.0\n", + " 42.0\n", + " 33.0\n", + " 11.0\n", " 0.0\n", " 0.0\n", " 3\n", @@ -320,12 +320,12 @@ " \n", " 213\n", " 43.0\n", - " 28.0\n", - " 1.0\n", - " 16.0\n", - " 31.0\n", + " 46.0\n", + " 6.0\n", + " 23.0\n", + " 23.0\n", " 0.0\n", - " 5.0\n", + " 15.0\n", " 0.0\n", " 0.0\n", " 2\n", @@ -337,17 +337,17 @@ ], "text/plain": [ " RI Na Mg Al Si 'K' Ca Ba Fe Type\n", - "0 30.0 8.0 11.0 12.0 44.0 30.0 4.0 0.0 0.0 0\n", - "1 22.0 3.0 11.0 13.0 43.0 27.0 4.0 0.0 0.0 1\n", - "2 30.0 15.0 11.0 14.0 29.0 28.0 4.0 0.0 0.0 0\n", - "3 3.0 32.0 1.0 16.0 52.0 0.0 3.0 0.0 0.0 2\n", - "4 63.0 4.0 0.0 9.0 0.0 10.0 11.0 0.0 5.0 3\n", + "0 30.0 14.0 16.0 18.0 38.0 32.0 14.0 0.0 0.0 0\n", + "1 17.0 3.0 18.0 21.0 34.0 24.0 10.0 0.0 0.0 1\n", + "2 30.0 24.0 15.0 22.0 22.0 27.0 6.0 0.0 0.0 0\n", + "3 3.0 51.0 6.0 23.0 47.0 0.0 3.0 0.0 0.0 2\n", + "4 62.0 4.0 0.0 13.0 0.0 8.0 30.0 0.0 5.0 3\n", ".. ... ... ... ... ... ... ... ... ... ...\n", - "209 16.0 19.0 10.0 13.0 32.0 28.0 4.0 0.0 0.0 1\n", - "210 14.0 11.0 11.0 26.0 30.0 30.0 3.0 0.0 0.0 3\n", - "211 18.0 23.0 11.0 13.0 42.0 15.0 3.0 0.0 5.0 3\n", - "212 22.0 5.0 3.0 16.0 47.0 30.0 4.0 0.0 0.0 3\n", - "213 43.0 28.0 1.0 16.0 31.0 0.0 5.0 0.0 0.0 2\n", + "209 13.0 33.0 11.0 19.0 23.0 27.0 4.0 0.0 0.0 1\n", + "210 11.0 19.0 18.0 29.0 23.0 33.0 3.0 0.0 0.0 3\n", + "211 14.0 41.0 18.0 20.0 34.0 14.0 3.0 0.0 5.0 3\n", + "212 20.0 8.0 8.0 23.0 42.0 33.0 11.0 0.0 0.0 3\n", + "213 43.0 46.0 6.0 23.0 23.0 0.0 15.0 0.0 0.0 2\n", "\n", "[214 rows x 10 columns]" ] @@ -373,210 +373,237 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 17, + "id": "2840a103-99fb-466f-ae75-45e11c1b9c5a", + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.model_selection import cross_validate, StratifiedKFold, KFold, cross_val_score\n", + "import numpy as np\n", + "n_folds = 5\n", + "score_name = \"accuracy\"\n", + "random_state=17\n", + "def validate_classifier(model, X, y, stratified, fit_params):\n", + " stratified_class = StratifiedKFold if stratified else KFold\n", + " kfold = stratified_class(shuffle=True, random_state=random_state, n_splits=n_folds)\n", + " #return cross_validate(model, X, y, cv=kfold, return_estimator=True, scoring=score_name)\n", + " return cross_val_score(model, X, y, fit_params=fit_params)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, "id": "6a1aad95-370f-4854-ae9a-32205aff5d39", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "b7c115c93e41439fa707a7a53d4e09de", + "model_id": "b620372c05294afc853885da0848e389", "version_major": 2, "version_minor": 0 }, "text/plain": [ - " 0%| | 0/210 [00:00" - ] - }, - "metadata": {}, - "output_type": "display_data" + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n", + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n", + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n", + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n", + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n", + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n", + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n" + ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "42459beb4c964bd9a7e42993c315407e", + "model_id": "ad33ee9f224d4abfa9a23338f07b32f2", "version_major": 2, "version_minor": 0 }, "text/plain": [ - " 0%| | 0/210 [00:00" - ] - }, - "metadata": {}, - "output_type": "display_data" + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n", + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n", + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n", + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n", + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n", + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n", + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n", + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n" + ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "56c8d86e92d84b6a900554896791e2f8", + "model_id": "fa05948fc73d43da8a3f01adc71c5e53", "version_major": 2, "version_minor": 0 }, "text/plain": [ - " 0%| | 0/210 [00:00" - ] - }, - "metadata": {}, - "output_type": "display_data" + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n", + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n", + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n", + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n", + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n", + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n", + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n" + ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "fbb1f212cc4f4977bf820f589d706f10", + "model_id": "1b67f117ceed43818c98342221b11184", "version_major": 2, "version_minor": 0 }, "text/plain": [ - " 0%| | 0/210 [00:00" - ] - }, - "metadata": {}, - "output_type": "display_data" + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n", + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n", + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n", + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n", + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n", + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n" + ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "2ba8efb4ba7c4104bc301b57bd8e6a74", + "model_id": "2399fb4327a242f0bd1c3e64b84382a6", "version_major": 2, "version_minor": 0 }, "text/plain": [ - " 0%| | 0/210 [00:00" - ] - }, - "metadata": {}, - "output_type": "display_data" + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n", + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n", + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n", + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n", + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n", + "/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n", + " warnings.warn(\n" + ] }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "0b254f30f950426aad8ab7a186ba5305", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/210 [00:00" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "704e76b1646c4ee49960d95390d2a6d1", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/210 [00:00" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "302b72893f344087affbbdc5c5a44c50", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/210 [00:00" - ] - }, - "metadata": {}, - "output_type": "display_data" + "ename": "IndexError", + "evalue": "only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn [19], line 9\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m head \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m4\u001b[39m):\n\u001b[1;32m 7\u001b[0m \u001b[38;5;66;03m#model.fit(X, y, head=head, features=features, class_name=class_name)\u001b[39;00m\n\u001b[1;32m 8\u001b[0m score \u001b[38;5;241m=\u001b[39m validate_classifier(model, X, y, stratified\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, fit_params\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mdict\u001b[39m(head\u001b[38;5;241m=\u001b[39mhead, features\u001b[38;5;241m=\u001b[39mfeatures, class_name\u001b[38;5;241m=\u001b[39mclass_name))\n\u001b[0;32m----> 9\u001b[0m model\u001b[38;5;241m.\u001b[39mplot(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msimple_init=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00msimple_init\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m head=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhead\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m score=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnp\u001b[38;5;241m.\u001b[39mmean(score[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtest_score\u001b[39m\u001b[38;5;124m'\u001b[39m])\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n", + "\u001b[0;31mIndexError\u001b[0m: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices" + ] } ], "source": [ + "import warnings\n", + "from stree import Stree\n", + "warnings.filterwarnings('ignore')\n", "for simple_init in [False, True]:\n", " model = TAN(simple_init=simple_init)\n", " for head in range(4):\n", - " model.fit(X, y, head=head, features=features, class_name=class_name)\n", - " model.plot(f\"simple_init={simple_init} head={head} score={model.score(X, y)}\")" + " #model.fit(X, y, head=head, features=features, class_name=class_name)\n", + " score = validate_classifier(model, X, y, stratified=False, fit_params=dict(head=head, features=features, class_name=class_name))\n", + " #model.plot(f\"simple_init={simple_init} head={head} score={np.mean(score['test_score'])}\")\n", + " model.plot(f\"simple_init={simple_init} head={head} score={np.mean(score)}\")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c389ff1e-76d9-4c5b-9860-ea6d4752fac7", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9c58629f-000b-4d8c-8896-efd032f1090c", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/bayesclass/bayesclass.py b/bayesclass/bayesclass.py index 17e14d3..63cade9 100644 --- a/bayesclass/bayesclass.py +++ b/bayesclass/bayesclass.py @@ -1,12 +1,18 @@ """ This is a module to be used as a reference for building other modules """ +import random +from itertools import combinations import pandas as pd from sklearn.base import ClassifierMixin, BaseEstimator from sklearn.utils.validation import check_X_y, check_array, check_is_fitted from sklearn.utils.multiclass import unique_labels import networkx as nx -from pgmpy.estimators import TreeSearch, BayesianEstimator +from pgmpy.estimators import ( + TreeSearch, + BayesianEstimator, + MaximumLikelihoodEstimator, +) from pgmpy.models import BayesianNetwork import matplotlib.pyplot as plt @@ -29,9 +35,12 @@ class TAN(ClassifierMixin, BaseEstimator): The classes seen at :meth:`fit`. """ - def __init__(self, simple_init=False, show_progress=False): + def __init__( + self, simple_init=False, show_progress=False, random_state=None + ): self.simple_init = simple_init self.show_progress = show_progress + self.random_state = random_state def fit(self, X, y, **kwargs): """A reference implementation of a fitting function for a classifier. @@ -44,7 +53,8 @@ class TAN(ClassifierMixin, BaseEstimator): **kwargs : dict class_name : str (default='class') Name of the class column features: list (default=None) List of features - head: int (default=0) Index of the head node + head: int (default=None) Index of the head node. Default value + gets the node with the highest sum of weights (mutual_info) Returns ------- self : object @@ -57,20 +67,22 @@ class TAN(ClassifierMixin, BaseEstimator): # Default values self.class_name_ = "class" self.features_ = [f"feature_{i}" for i in range(X.shape[1])] - self.head_ = 0 + self.head_ = None expected_args = ["class_name", "features", "head"] for key, value in kwargs.items(): if key in expected_args: setattr(self, f"{key}_", value) - else: raise ValueError(f"Unexpected argument: {key}") - + if self.random_state is not None: + random.seed(self.random_state) + if self.head_ == "random": + self.head_ = random.randint(0, len(self.features_) - 1) if len(self.features_) != X.shape[1]: raise ValueError( "Number of features does not match the number of columns in X" ) - if self.head_ >= len(self.features_): + if self.head_ is not None and self.head_ >= len(self.features_): raise ValueError("Head index out of range") self.X_ = X @@ -80,37 +92,57 @@ class TAN(ClassifierMixin, BaseEstimator): return self def __initial_edges(self): + """As with the naive Bayes, in a TAN structure, the class has no + parents, while features must have the class as parent and are forced to + have one other feature as parent too (except for one single feature, + which has only the class as parent and is considered the root of the + features' tree) + Cassio P. de Campos, Giorgio Corani, Mauro Scanagatta, Marco Cuccu, + Marco Zaffalon, + Learning extended tree augmented naive structures, + International Journal of Approximate Reasoning, + Returns + ------- + List + List of edges + """ + head = 0 if self.head_ is None else self.head_ if self.simple_init: - first_node = self.features_[self.head_] + first_node = self.features_[head] return [ (first_node, feature) for feature in self.features_ if feature != first_node ] - edges = [] - for i in range(len(self.features_)): - for j in range(i + 1, len(self.features_)): - edges.append((self.features_[i], self.features_[j])) - return edges + # initialize a complete network with all edges starting from head + reordered = [ + self.features_[idx % len(self.features_)] + for idx in range(head, len(self.features_) + head) + ] + return list(combinations(reordered, 2)) def __train(self): + # Initialize a Naive Bayes model net = [(self.class_name_, feature) for feature in self.features_] self.model_ = BayesianNetwork(net) # initialize a complete network with all edges self.model_.add_edges_from(self.__initial_edges()) - self.dataset_ = pd.DataFrame(self.X_, columns=self.features_) self.dataset_[self.class_name_] = self.y_ # learn graph structure - est = TreeSearch(self.dataset_, root_node=self.features_[self.head_]) + root_node = None if self.head_ is None else self.features_[self.head_] + est = TreeSearch(self.dataset_, root_node=root_node) dag = est.estimate( estimator_type="tan", class_node=self.class_name_, show_progress=self.show_progress, ) + if self.head_ is None: + self.head_ = est.root_node self.model_ = BayesianNetwork(dag.edges()) self.model_.fit( self.dataset_, + # estimator=MaximumLikelihoodEstimator, estimator=BayesianEstimator, prior_type="K2", ) diff --git a/bayesclass/tests/test_bayesclass.py b/bayesclass/tests/test_bayesclass.py index 24a8dd6..c805878 100644 --- a/bayesclass/tests/test_bayesclass.py +++ b/bayesclass/tests/test_bayesclass.py @@ -16,12 +16,26 @@ def data(): return enc.fit_transform(X), y -def test_TAN_classifier(data): +def test_TAN_constructor(): clf = TAN() - # Test default values of hyperparameters assert not clf.simple_init assert not clf.show_progress + assert clf.random_state is None + clf = TAN(simple_init=True, show_progress=True, random_state=17) + assert clf.simple_init + assert clf.show_progress + assert clf.random_state == 17 + + +def test_TAN_random_head(data): + clf = TAN(random_state=17) + clf.fit(*data, head="random") + assert clf.head_ == 3 + + +def test_TAN_classifier(data): + clf = TAN() clf.fit(*data) attribs = ["classes_", "X_", "y_", "head_", "features_", "class_name_"] diff --git a/test.ipynb b/test.ipynb index d59983b..6a42616 100644 --- a/test.ipynb +++ b/test.ipynb @@ -189,39 +189,39 @@ " \n", " \n", " 0\n", - " 31.0\n", - " 8.0\n", - " 15.0\n", - " 13.0\n", + " 30.0\n", + " 14.0\n", + " 16.0\n", + " 18.0\n", " 38.0\n", - " 26.0\n", - " 9.0\n", + " 32.0\n", + " 14.0\n", " 0.0\n", " 0.0\n", " 0\n", " \n", " \n", " 1\n", - " 23.0\n", + " 17.0\n", " 3.0\n", - " 15.0\n", - " 19.0\n", - " 36.0\n", - " 19.0\n", - " 9.0\n", + " 18.0\n", + " 21.0\n", + " 34.0\n", + " 24.0\n", + " 10.0\n", " 0.0\n", " 0.0\n", " 1\n", " \n", " \n", " 2\n", - " 31.0\n", - " 17.0\n", - " 15.0\n", - " 20.0\n", + " 30.0\n", " 24.0\n", - " 21.0\n", - " 7.0\n", + " 15.0\n", + " 22.0\n", + " 22.0\n", + " 27.0\n", + " 6.0\n", " 0.0\n", " 0.0\n", " 0\n", @@ -229,9 +229,9 @@ " \n", " 3\n", " 3.0\n", - " 42.0\n", + " 51.0\n", " 6.0\n", - " 21.0\n", + " 23.0\n", " 47.0\n", " 0.0\n", " 3.0\n", @@ -241,15 +241,15 @@ " \n", " \n", " 4\n", - " 63.0\n", + " 62.0\n", " 4.0\n", " 0.0\n", - " 11.0\n", + " 13.0\n", " 0.0\n", " 8.0\n", - " 21.0\n", + " 30.0\n", " 0.0\n", - " 4.0\n", + " 5.0\n", " 3\n", " \n", " \n", @@ -267,12 +267,12 @@ " \n", " \n", " 209\n", - " 17.0\n", - " 22.0\n", - " 14.0\n", - " 15.0\n", - " 26.0\n", - " 21.0\n", + " 13.0\n", + " 33.0\n", + " 11.0\n", + " 19.0\n", + " 23.0\n", + " 27.0\n", " 4.0\n", " 0.0\n", " 0.0\n", @@ -280,12 +280,12 @@ " \n", " \n", " 210\n", - " 14.0\n", - " 10.0\n", - " 15.0\n", - " 27.0\n", - " 25.0\n", - " 30.0\n", + " 11.0\n", + " 19.0\n", + " 18.0\n", + " 29.0\n", + " 23.0\n", + " 33.0\n", " 3.0\n", " 0.0\n", " 0.0\n", @@ -293,39 +293,39 @@ " \n", " \n", " 211\n", - " 19.0\n", - " 33.0\n", - " 15.0\n", - " 17.0\n", - " 36.0\n", - " 12.0\n", + " 14.0\n", + " 41.0\n", + " 18.0\n", + " 20.0\n", + " 34.0\n", + " 14.0\n", " 3.0\n", " 0.0\n", - " 4.0\n", + " 5.0\n", " 3\n", " \n", " \n", " 212\n", - " 23.0\n", - " 5.0\n", + " 20.0\n", " 8.0\n", - " 21.0\n", - " 43.0\n", - " 30.0\n", - " 9.0\n", + " 8.0\n", + " 23.0\n", + " 42.0\n", + " 33.0\n", + " 11.0\n", " 0.0\n", " 0.0\n", " 3\n", " \n", " \n", " 213\n", - " 44.0\n", - " 38.0\n", + " 43.0\n", + " 46.0\n", " 6.0\n", - " 21.0\n", - " 25.0\n", + " 23.0\n", + " 23.0\n", " 0.0\n", - " 10.0\n", + " 15.0\n", " 0.0\n", " 0.0\n", " 2\n", @@ -337,17 +337,17 @@ ], "text/plain": [ " RI Na Mg Al Si 'K' Ca Ba Fe Type\n", - "0 31.0 8.0 15.0 13.0 38.0 26.0 9.0 0.0 0.0 0\n", - "1 23.0 3.0 15.0 19.0 36.0 19.0 9.0 0.0 0.0 1\n", - "2 31.0 17.0 15.0 20.0 24.0 21.0 7.0 0.0 0.0 0\n", - "3 3.0 42.0 6.0 21.0 47.0 0.0 3.0 0.0 0.0 2\n", - "4 63.0 4.0 0.0 11.0 0.0 8.0 21.0 0.0 4.0 3\n", + "0 30.0 14.0 16.0 18.0 38.0 32.0 14.0 0.0 0.0 0\n", + "1 17.0 3.0 18.0 21.0 34.0 24.0 10.0 0.0 0.0 1\n", + "2 30.0 24.0 15.0 22.0 22.0 27.0 6.0 0.0 0.0 0\n", + "3 3.0 51.0 6.0 23.0 47.0 0.0 3.0 0.0 0.0 2\n", + "4 62.0 4.0 0.0 13.0 0.0 8.0 30.0 0.0 5.0 3\n", ".. ... ... ... ... ... ... ... ... ... ...\n", - "209 17.0 22.0 14.0 15.0 26.0 21.0 4.0 0.0 0.0 1\n", - "210 14.0 10.0 15.0 27.0 25.0 30.0 3.0 0.0 0.0 3\n", - "211 19.0 33.0 15.0 17.0 36.0 12.0 3.0 0.0 4.0 3\n", - "212 23.0 5.0 8.0 21.0 43.0 30.0 9.0 0.0 0.0 3\n", - "213 44.0 38.0 6.0 21.0 25.0 0.0 10.0 0.0 0.0 2\n", + "209 13.0 33.0 11.0 19.0 23.0 27.0 4.0 0.0 0.0 1\n", + "210 11.0 19.0 18.0 29.0 23.0 33.0 3.0 0.0 0.0 3\n", + "211 14.0 41.0 18.0 20.0 34.0 14.0 3.0 0.0 5.0 3\n", + "212 20.0 8.0 8.0 23.0 42.0 33.0 11.0 0.0 0.0 3\n", + "213 43.0 46.0 6.0 23.0 23.0 0.0 15.0 0.0 0.0 2\n", "\n", "[214 rows x 10 columns]" ] @@ -373,39 +373,317 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "6a1aad95-370f-4854-ae9a-32205aff5d39", + "execution_count": 17, + "id": "2840a103-99fb-466f-ae75-45e11c1b9c5a", "metadata": {}, "outputs": [], "source": [ - "for simple_init in [False, True]:\n", - " model = TAN(simple_init=simple_init)\n", - " for head in range(4):\n", - " model.fit(X, y, head=head, features=features, class_name=class_name)\n", - " ypred = model.predict(X)\n", - " #model.plot(f\"simple_init={simple_init} head={head} score={model.predict(X)}\")" + "from sklearn.model_selection import cross_validate, StratifiedKFold, KFold, cross_val_score\n", + "import numpy as np\n", + "n_folds = 5\n", + "score_name = \"accuracy\"\n", + "random_state=17\n", + "def validate_classifier(model, X, y, stratified, fit_params):\n", + " stratified_class = StratifiedKFold if stratified else KFold\n", + " kfold = stratified_class(shuffle=True, random_state=random_state, n_splits=n_folds)\n", + " #return cross_validate(model, X, y, cv=kfold, return_estimator=True, scoring=score_name)\n", + " return cross_val_score(model, X, y, fit_params=fit_params)" ] }, { "cell_type": "code", - "execution_count": 5, - "id": "76905bf3", + "execution_count": 20, + "id": "6a1aad95-370f-4854-ae9a-32205aff5d39", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "7bdb666c5e5140e688141356958b362f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/43 [00:00 10\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mplot\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43msimple_init=\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43msimple_init\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m head=\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mhead\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m score=\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmean\u001b[49m\u001b[43m(\u001b[49m\u001b[43mscore\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Code/bayesclass/bayesclass/bayesclass.py:148\u001b[0m, in \u001b[0;36mTAN.plot\u001b[0;34m(self, title)\u001b[0m\n\u001b[1;32m 146\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mplot\u001b[39m(\u001b[38;5;28mself\u001b[39m, title\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 147\u001b[0m nx\u001b[38;5;241m.\u001b[39mdraw_circular(\n\u001b[0;32m--> 148\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel_\u001b[49m,\n\u001b[1;32m 149\u001b[0m with_labels\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 150\u001b[0m arrowsize\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m30\u001b[39m,\n\u001b[1;32m 151\u001b[0m node_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m800\u001b[39m,\n\u001b[1;32m 152\u001b[0m alpha\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.3\u001b[39m,\n\u001b[1;32m 153\u001b[0m font_weight\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbold\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 154\u001b[0m )\n\u001b[1;32m 155\u001b[0m plt\u001b[38;5;241m.\u001b[39mtitle(title)\n\u001b[1;32m 156\u001b[0m plt\u001b[38;5;241m.\u001b[39mshow()\n", + "\u001b[0;31mAttributeError\u001b[0m: 'TAN' object has no attribute 'model_'" + ] + } + ], + "source": [ + "import warnings\n", + "from stree import Stree\n", + "warnings.filterwarnings('ignore')\n", + "for simple_init in [False, True]:\n", + " model = TAN(simple_init=simple_init)\n", + " for head in range(4):\n", + " #model.fit(X, y, head=head, features=features, class_name=class_name)\n", + " score = validate_classifier(model, X, y, stratified=False, fit_params=dict(head=head, features=features, class_name=class_name))\n", + " #model.plot(f\"simple_init={simple_init} head={head} score={np.mean(score['test_score'])}\")\n", + " model.plot(f\"simple_init={simple_init} head={head} score={np.mean(score)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "c389ff1e-76d9-4c5b-9860-ea6d4752fac7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "(214, 9)" + "array([nan, nan, nan, nan, nan])" ] }, - "execution_count": 5, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "X.shape\n" + "score" ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "9c58629f-000b-4d8c-8896-efd032f1090c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "b 10\n", + "c 9\n", + "d 8\n", + "e 7\n", + "a 6\n" + ] + } + ], + "source": [ + "from queue import PriorityQueue\n", + "q = PriorityQueue()\n", + "lista = ['b', 'c', 'd', 'e', 'a']\n", + "for i, c in zip(lista, range(len(lista))):\n", + " print(i,10-c)\n", + " q.put(i,10-c)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "e2a768c0-3e21-48f3-b118-25408122d01c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "a\n", + "b\n", + "c\n", + "d\n", + "e\n" + ] + } + ], + "source": [ + "while not q.empty():\n", + " print(q.get())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "96bb1acd-f450-4b9c-8f54-f020e23dfc14", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": {