Update head hyperparam to use highest weight

This commit is contained in:
2022-11-07 00:51:19 +01:00
parent 8c03fc6b67
commit 02110f7608
4 changed files with 630 additions and 279 deletions

File diff suppressed because one or more lines are too long

View File

@@ -1,12 +1,18 @@
"""
This is a module to be used as a reference for building other modules
"""
import random
from itertools import combinations
import pandas as pd
from sklearn.base import ClassifierMixin, BaseEstimator
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
from sklearn.utils.multiclass import unique_labels
import networkx as nx
from pgmpy.estimators import TreeSearch, BayesianEstimator
from pgmpy.estimators import (
TreeSearch,
BayesianEstimator,
MaximumLikelihoodEstimator,
)
from pgmpy.models import BayesianNetwork
import matplotlib.pyplot as plt
@@ -29,9 +35,12 @@ class TAN(ClassifierMixin, BaseEstimator):
The classes seen at :meth:`fit`.
"""
def __init__(self, simple_init=False, show_progress=False):
def __init__(
self, simple_init=False, show_progress=False, random_state=None
):
self.simple_init = simple_init
self.show_progress = show_progress
self.random_state = random_state
def fit(self, X, y, **kwargs):
"""A reference implementation of a fitting function for a classifier.
@@ -44,7 +53,8 @@ class TAN(ClassifierMixin, BaseEstimator):
**kwargs : dict
class_name : str (default='class') Name of the class column
features: list (default=None) List of features
head: int (default=0) Index of the head node
head: int (default=None) Index of the head node. Default value
gets the node with the highest sum of weights (mutual_info)
Returns
-------
self : object
@@ -57,20 +67,22 @@ class TAN(ClassifierMixin, BaseEstimator):
# Default values
self.class_name_ = "class"
self.features_ = [f"feature_{i}" for i in range(X.shape[1])]
self.head_ = 0
self.head_ = None
expected_args = ["class_name", "features", "head"]
for key, value in kwargs.items():
if key in expected_args:
setattr(self, f"{key}_", value)
else:
raise ValueError(f"Unexpected argument: {key}")
if self.random_state is not None:
random.seed(self.random_state)
if self.head_ == "random":
self.head_ = random.randint(0, len(self.features_) - 1)
if len(self.features_) != X.shape[1]:
raise ValueError(
"Number of features does not match the number of columns in X"
)
if self.head_ >= len(self.features_):
if self.head_ is not None and self.head_ >= len(self.features_):
raise ValueError("Head index out of range")
self.X_ = X
@@ -80,37 +92,57 @@ class TAN(ClassifierMixin, BaseEstimator):
return self
def __initial_edges(self):
"""As with the naive Bayes, in a TAN structure, the class has no
parents, while features must have the class as parent and are forced to
have one other feature as parent too (except for one single feature,
which has only the class as parent and is considered the root of the
features' tree)
Cassio P. de Campos, Giorgio Corani, Mauro Scanagatta, Marco Cuccu,
Marco Zaffalon,
Learning extended tree augmented naive structures,
International Journal of Approximate Reasoning,
Returns
-------
List
List of edges
"""
head = 0 if self.head_ is None else self.head_
if self.simple_init:
first_node = self.features_[self.head_]
first_node = self.features_[head]
return [
(first_node, feature)
for feature in self.features_
if feature != first_node
]
edges = []
for i in range(len(self.features_)):
for j in range(i + 1, len(self.features_)):
edges.append((self.features_[i], self.features_[j]))
return edges
# initialize a complete network with all edges starting from head
reordered = [
self.features_[idx % len(self.features_)]
for idx in range(head, len(self.features_) + head)
]
return list(combinations(reordered, 2))
def __train(self):
# Initialize a Naive Bayes model
net = [(self.class_name_, feature) for feature in self.features_]
self.model_ = BayesianNetwork(net)
# initialize a complete network with all edges
self.model_.add_edges_from(self.__initial_edges())
self.dataset_ = pd.DataFrame(self.X_, columns=self.features_)
self.dataset_[self.class_name_] = self.y_
# learn graph structure
est = TreeSearch(self.dataset_, root_node=self.features_[self.head_])
root_node = None if self.head_ is None else self.features_[self.head_]
est = TreeSearch(self.dataset_, root_node=root_node)
dag = est.estimate(
estimator_type="tan",
class_node=self.class_name_,
show_progress=self.show_progress,
)
if self.head_ is None:
self.head_ = est.root_node
self.model_ = BayesianNetwork(dag.edges())
self.model_.fit(
self.dataset_,
# estimator=MaximumLikelihoodEstimator,
estimator=BayesianEstimator,
prior_type="K2",
)

View File

@@ -16,12 +16,26 @@ def data():
return enc.fit_transform(X), y
def test_TAN_classifier(data):
def test_TAN_constructor():
clf = TAN()
# Test default values of hyperparameters
assert not clf.simple_init
assert not clf.show_progress
assert clf.random_state is None
clf = TAN(simple_init=True, show_progress=True, random_state=17)
assert clf.simple_init
assert clf.show_progress
assert clf.random_state == 17
def test_TAN_random_head(data):
clf = TAN(random_state=17)
clf.fit(*data, head="random")
assert clf.head_ == 3
def test_TAN_classifier(data):
clf = TAN()
clf.fit(*data)
attribs = ["classes_", "X_", "y_", "head_", "features_", "class_name_"]

View File

@@ -189,39 +189,39 @@
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>31.0</td>\n",
" <td>8.0</td>\n",
" <td>15.0</td>\n",
" <td>13.0</td>\n",
" <td>30.0</td>\n",
" <td>14.0</td>\n",
" <td>16.0</td>\n",
" <td>18.0</td>\n",
" <td>38.0</td>\n",
" <td>26.0</td>\n",
" <td>9.0</td>\n",
" <td>32.0</td>\n",
" <td>14.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>23.0</td>\n",
" <td>17.0</td>\n",
" <td>3.0</td>\n",
" <td>15.0</td>\n",
" <td>19.0</td>\n",
" <td>36.0</td>\n",
" <td>19.0</td>\n",
" <td>9.0</td>\n",
" <td>18.0</td>\n",
" <td>21.0</td>\n",
" <td>34.0</td>\n",
" <td>24.0</td>\n",
" <td>10.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>31.0</td>\n",
" <td>17.0</td>\n",
" <td>15.0</td>\n",
" <td>20.0</td>\n",
" <td>30.0</td>\n",
" <td>24.0</td>\n",
" <td>21.0</td>\n",
" <td>7.0</td>\n",
" <td>15.0</td>\n",
" <td>22.0</td>\n",
" <td>22.0</td>\n",
" <td>27.0</td>\n",
" <td>6.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0</td>\n",
@@ -229,9 +229,9 @@
" <tr>\n",
" <th>3</th>\n",
" <td>3.0</td>\n",
" <td>42.0</td>\n",
" <td>51.0</td>\n",
" <td>6.0</td>\n",
" <td>21.0</td>\n",
" <td>23.0</td>\n",
" <td>47.0</td>\n",
" <td>0.0</td>\n",
" <td>3.0</td>\n",
@@ -241,15 +241,15 @@
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>63.0</td>\n",
" <td>62.0</td>\n",
" <td>4.0</td>\n",
" <td>0.0</td>\n",
" <td>11.0</td>\n",
" <td>13.0</td>\n",
" <td>0.0</td>\n",
" <td>8.0</td>\n",
" <td>21.0</td>\n",
" <td>30.0</td>\n",
" <td>0.0</td>\n",
" <td>4.0</td>\n",
" <td>5.0</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
@@ -267,12 +267,12 @@
" </tr>\n",
" <tr>\n",
" <th>209</th>\n",
" <td>17.0</td>\n",
" <td>22.0</td>\n",
" <td>14.0</td>\n",
" <td>15.0</td>\n",
" <td>26.0</td>\n",
" <td>21.0</td>\n",
" <td>13.0</td>\n",
" <td>33.0</td>\n",
" <td>11.0</td>\n",
" <td>19.0</td>\n",
" <td>23.0</td>\n",
" <td>27.0</td>\n",
" <td>4.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
@@ -280,12 +280,12 @@
" </tr>\n",
" <tr>\n",
" <th>210</th>\n",
" <td>14.0</td>\n",
" <td>10.0</td>\n",
" <td>15.0</td>\n",
" <td>27.0</td>\n",
" <td>25.0</td>\n",
" <td>30.0</td>\n",
" <td>11.0</td>\n",
" <td>19.0</td>\n",
" <td>18.0</td>\n",
" <td>29.0</td>\n",
" <td>23.0</td>\n",
" <td>33.0</td>\n",
" <td>3.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
@@ -293,39 +293,39 @@
" </tr>\n",
" <tr>\n",
" <th>211</th>\n",
" <td>19.0</td>\n",
" <td>33.0</td>\n",
" <td>15.0</td>\n",
" <td>17.0</td>\n",
" <td>36.0</td>\n",
" <td>12.0</td>\n",
" <td>14.0</td>\n",
" <td>41.0</td>\n",
" <td>18.0</td>\n",
" <td>20.0</td>\n",
" <td>34.0</td>\n",
" <td>14.0</td>\n",
" <td>3.0</td>\n",
" <td>0.0</td>\n",
" <td>4.0</td>\n",
" <td>5.0</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>212</th>\n",
" <td>23.0</td>\n",
" <td>5.0</td>\n",
" <td>20.0</td>\n",
" <td>8.0</td>\n",
" <td>21.0</td>\n",
" <td>43.0</td>\n",
" <td>30.0</td>\n",
" <td>9.0</td>\n",
" <td>8.0</td>\n",
" <td>23.0</td>\n",
" <td>42.0</td>\n",
" <td>33.0</td>\n",
" <td>11.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>213</th>\n",
" <td>44.0</td>\n",
" <td>38.0</td>\n",
" <td>43.0</td>\n",
" <td>46.0</td>\n",
" <td>6.0</td>\n",
" <td>21.0</td>\n",
" <td>25.0</td>\n",
" <td>23.0</td>\n",
" <td>23.0</td>\n",
" <td>0.0</td>\n",
" <td>10.0</td>\n",
" <td>15.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>2</td>\n",
@@ -337,17 +337,17 @@
],
"text/plain": [
" RI Na Mg Al Si 'K' Ca Ba Fe Type\n",
"0 31.0 8.0 15.0 13.0 38.0 26.0 9.0 0.0 0.0 0\n",
"1 23.0 3.0 15.0 19.0 36.0 19.0 9.0 0.0 0.0 1\n",
"2 31.0 17.0 15.0 20.0 24.0 21.0 7.0 0.0 0.0 0\n",
"3 3.0 42.0 6.0 21.0 47.0 0.0 3.0 0.0 0.0 2\n",
"4 63.0 4.0 0.0 11.0 0.0 8.0 21.0 0.0 4.0 3\n",
"0 30.0 14.0 16.0 18.0 38.0 32.0 14.0 0.0 0.0 0\n",
"1 17.0 3.0 18.0 21.0 34.0 24.0 10.0 0.0 0.0 1\n",
"2 30.0 24.0 15.0 22.0 22.0 27.0 6.0 0.0 0.0 0\n",
"3 3.0 51.0 6.0 23.0 47.0 0.0 3.0 0.0 0.0 2\n",
"4 62.0 4.0 0.0 13.0 0.0 8.0 30.0 0.0 5.0 3\n",
".. ... ... ... ... ... ... ... ... ... ...\n",
"209 17.0 22.0 14.0 15.0 26.0 21.0 4.0 0.0 0.0 1\n",
"210 14.0 10.0 15.0 27.0 25.0 30.0 3.0 0.0 0.0 3\n",
"211 19.0 33.0 15.0 17.0 36.0 12.0 3.0 0.0 4.0 3\n",
"212 23.0 5.0 8.0 21.0 43.0 30.0 9.0 0.0 0.0 3\n",
"213 44.0 38.0 6.0 21.0 25.0 0.0 10.0 0.0 0.0 2\n",
"209 13.0 33.0 11.0 19.0 23.0 27.0 4.0 0.0 0.0 1\n",
"210 11.0 19.0 18.0 29.0 23.0 33.0 3.0 0.0 0.0 3\n",
"211 14.0 41.0 18.0 20.0 34.0 14.0 3.0 0.0 5.0 3\n",
"212 20.0 8.0 8.0 23.0 42.0 33.0 11.0 0.0 0.0 3\n",
"213 43.0 46.0 6.0 23.0 23.0 0.0 15.0 0.0 0.0 2\n",
"\n",
"[214 rows x 10 columns]"
]
@@ -373,39 +373,317 @@
},
{
"cell_type": "code",
"execution_count": null,
"id": "6a1aad95-370f-4854-ae9a-32205aff5d39",
"execution_count": 17,
"id": "2840a103-99fb-466f-ae75-45e11c1b9c5a",
"metadata": {},
"outputs": [],
"source": [
"for simple_init in [False, True]:\n",
" model = TAN(simple_init=simple_init)\n",
" for head in range(4):\n",
" model.fit(X, y, head=head, features=features, class_name=class_name)\n",
" ypred = model.predict(X)\n",
" #model.plot(f\"simple_init={simple_init} head={head} score={model.predict(X)}\")"
"from sklearn.model_selection import cross_validate, StratifiedKFold, KFold, cross_val_score\n",
"import numpy as np\n",
"n_folds = 5\n",
"score_name = \"accuracy\"\n",
"random_state=17\n",
"def validate_classifier(model, X, y, stratified, fit_params):\n",
" stratified_class = StratifiedKFold if stratified else KFold\n",
" kfold = stratified_class(shuffle=True, random_state=random_state, n_splits=n_folds)\n",
" #return cross_validate(model, X, y, cv=kfold, return_estimator=True, scoring=score_name)\n",
" return cross_val_score(model, X, y, fit_params=fit_params)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "76905bf3",
"execution_count": 20,
"id": "6a1aad95-370f-4854-ae9a-32205aff5d39",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7bdb666c5e5140e688141356958b362f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/43 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0110029f56a4451cb877eeaae42e8e3f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/43 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1e8130ed7d3a4f4da499dc481df369ab",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/43 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "52adf5efe6874dcfa1e776c6a3c47f2c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/43 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "27fc7a5b95434c7c86ab2ddde212e006",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/42 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n",
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
" warnings.warn(\n"
]
},
{
"ename": "AttributeError",
"evalue": "'TAN' object has no attribute 'model_'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn [20], line 10\u001b[0m\n\u001b[1;32m 8\u001b[0m score \u001b[38;5;241m=\u001b[39m validate_classifier(model, X, y, stratified\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, fit_params\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mdict\u001b[39m(head\u001b[38;5;241m=\u001b[39mhead, features\u001b[38;5;241m=\u001b[39mfeatures, class_name\u001b[38;5;241m=\u001b[39mclass_name))\n\u001b[1;32m 9\u001b[0m \u001b[38;5;66;03m#model.plot(f\"simple_init={simple_init} head={head} score={np.mean(score['test_score'])}\")\u001b[39;00m\n\u001b[0;32m---> 10\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mplot\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43msimple_init=\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43msimple_init\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m head=\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mhead\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m score=\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmean\u001b[49m\u001b[43m(\u001b[49m\u001b[43mscore\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Code/bayesclass/bayesclass/bayesclass.py:148\u001b[0m, in \u001b[0;36mTAN.plot\u001b[0;34m(self, title)\u001b[0m\n\u001b[1;32m 146\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mplot\u001b[39m(\u001b[38;5;28mself\u001b[39m, title\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 147\u001b[0m nx\u001b[38;5;241m.\u001b[39mdraw_circular(\n\u001b[0;32m--> 148\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel_\u001b[49m,\n\u001b[1;32m 149\u001b[0m with_labels\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 150\u001b[0m arrowsize\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m30\u001b[39m,\n\u001b[1;32m 151\u001b[0m node_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m800\u001b[39m,\n\u001b[1;32m 152\u001b[0m alpha\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.3\u001b[39m,\n\u001b[1;32m 153\u001b[0m font_weight\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbold\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 154\u001b[0m )\n\u001b[1;32m 155\u001b[0m plt\u001b[38;5;241m.\u001b[39mtitle(title)\n\u001b[1;32m 156\u001b[0m plt\u001b[38;5;241m.\u001b[39mshow()\n",
"\u001b[0;31mAttributeError\u001b[0m: 'TAN' object has no attribute 'model_'"
]
}
],
"source": [
"import warnings\n",
"from stree import Stree\n",
"warnings.filterwarnings('ignore')\n",
"for simple_init in [False, True]:\n",
" model = TAN(simple_init=simple_init)\n",
" for head in range(4):\n",
" #model.fit(X, y, head=head, features=features, class_name=class_name)\n",
" score = validate_classifier(model, X, y, stratified=False, fit_params=dict(head=head, features=features, class_name=class_name))\n",
" #model.plot(f\"simple_init={simple_init} head={head} score={np.mean(score['test_score'])}\")\n",
" model.plot(f\"simple_init={simple_init} head={head} score={np.mean(score)}\")"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "c389ff1e-76d9-4c5b-9860-ea6d4752fac7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(214, 9)"
"array([nan, nan, nan, nan, nan])"
]
},
"execution_count": 5,
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X.shape\n"
"score"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "9c58629f-000b-4d8c-8896-efd032f1090c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"b 10\n",
"c 9\n",
"d 8\n",
"e 7\n",
"a 6\n"
]
}
],
"source": [
"from queue import PriorityQueue\n",
"q = PriorityQueue()\n",
"lista = ['b', 'c', 'd', 'e', 'a']\n",
"for i, c in zip(lista, range(len(lista))):\n",
" print(i,10-c)\n",
" q.put(i,10-c)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "e2a768c0-3e21-48f3-b118-25408122d01c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"a\n",
"b\n",
"c\n",
"d\n",
"e\n"
]
}
],
"source": [
"while not q.empty():\n",
" print(q.get())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "96bb1acd-f450-4b9c-8f54-f020e23dfc14",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {