mirror of
https://github.com/Doctorado-ML/bayesclass.git
synced 2025-08-17 00:26:10 +00:00
Update head hyperparam to use highest weight
This commit is contained in:
File diff suppressed because one or more lines are too long
@@ -1,12 +1,18 @@
|
||||
"""
|
||||
This is a module to be used as a reference for building other modules
|
||||
"""
|
||||
import random
|
||||
from itertools import combinations
|
||||
import pandas as pd
|
||||
from sklearn.base import ClassifierMixin, BaseEstimator
|
||||
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
|
||||
from sklearn.utils.multiclass import unique_labels
|
||||
import networkx as nx
|
||||
from pgmpy.estimators import TreeSearch, BayesianEstimator
|
||||
from pgmpy.estimators import (
|
||||
TreeSearch,
|
||||
BayesianEstimator,
|
||||
MaximumLikelihoodEstimator,
|
||||
)
|
||||
from pgmpy.models import BayesianNetwork
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
@@ -29,9 +35,12 @@ class TAN(ClassifierMixin, BaseEstimator):
|
||||
The classes seen at :meth:`fit`.
|
||||
"""
|
||||
|
||||
def __init__(self, simple_init=False, show_progress=False):
|
||||
def __init__(
|
||||
self, simple_init=False, show_progress=False, random_state=None
|
||||
):
|
||||
self.simple_init = simple_init
|
||||
self.show_progress = show_progress
|
||||
self.random_state = random_state
|
||||
|
||||
def fit(self, X, y, **kwargs):
|
||||
"""A reference implementation of a fitting function for a classifier.
|
||||
@@ -44,7 +53,8 @@ class TAN(ClassifierMixin, BaseEstimator):
|
||||
**kwargs : dict
|
||||
class_name : str (default='class') Name of the class column
|
||||
features: list (default=None) List of features
|
||||
head: int (default=0) Index of the head node
|
||||
head: int (default=None) Index of the head node. Default value
|
||||
gets the node with the highest sum of weights (mutual_info)
|
||||
Returns
|
||||
-------
|
||||
self : object
|
||||
@@ -57,20 +67,22 @@ class TAN(ClassifierMixin, BaseEstimator):
|
||||
# Default values
|
||||
self.class_name_ = "class"
|
||||
self.features_ = [f"feature_{i}" for i in range(X.shape[1])]
|
||||
self.head_ = 0
|
||||
self.head_ = None
|
||||
expected_args = ["class_name", "features", "head"]
|
||||
for key, value in kwargs.items():
|
||||
if key in expected_args:
|
||||
setattr(self, f"{key}_", value)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unexpected argument: {key}")
|
||||
|
||||
if self.random_state is not None:
|
||||
random.seed(self.random_state)
|
||||
if self.head_ == "random":
|
||||
self.head_ = random.randint(0, len(self.features_) - 1)
|
||||
if len(self.features_) != X.shape[1]:
|
||||
raise ValueError(
|
||||
"Number of features does not match the number of columns in X"
|
||||
)
|
||||
if self.head_ >= len(self.features_):
|
||||
if self.head_ is not None and self.head_ >= len(self.features_):
|
||||
raise ValueError("Head index out of range")
|
||||
|
||||
self.X_ = X
|
||||
@@ -80,37 +92,57 @@ class TAN(ClassifierMixin, BaseEstimator):
|
||||
return self
|
||||
|
||||
def __initial_edges(self):
|
||||
"""As with the naive Bayes, in a TAN structure, the class has no
|
||||
parents, while features must have the class as parent and are forced to
|
||||
have one other feature as parent too (except for one single feature,
|
||||
which has only the class as parent and is considered the root of the
|
||||
features' tree)
|
||||
Cassio P. de Campos, Giorgio Corani, Mauro Scanagatta, Marco Cuccu,
|
||||
Marco Zaffalon,
|
||||
Learning extended tree augmented naive structures,
|
||||
International Journal of Approximate Reasoning,
|
||||
Returns
|
||||
-------
|
||||
List
|
||||
List of edges
|
||||
"""
|
||||
head = 0 if self.head_ is None else self.head_
|
||||
if self.simple_init:
|
||||
first_node = self.features_[self.head_]
|
||||
first_node = self.features_[head]
|
||||
return [
|
||||
(first_node, feature)
|
||||
for feature in self.features_
|
||||
if feature != first_node
|
||||
]
|
||||
edges = []
|
||||
for i in range(len(self.features_)):
|
||||
for j in range(i + 1, len(self.features_)):
|
||||
edges.append((self.features_[i], self.features_[j]))
|
||||
return edges
|
||||
# initialize a complete network with all edges starting from head
|
||||
reordered = [
|
||||
self.features_[idx % len(self.features_)]
|
||||
for idx in range(head, len(self.features_) + head)
|
||||
]
|
||||
return list(combinations(reordered, 2))
|
||||
|
||||
def __train(self):
|
||||
# Initialize a Naive Bayes model
|
||||
net = [(self.class_name_, feature) for feature in self.features_]
|
||||
self.model_ = BayesianNetwork(net)
|
||||
# initialize a complete network with all edges
|
||||
self.model_.add_edges_from(self.__initial_edges())
|
||||
|
||||
self.dataset_ = pd.DataFrame(self.X_, columns=self.features_)
|
||||
self.dataset_[self.class_name_] = self.y_
|
||||
# learn graph structure
|
||||
est = TreeSearch(self.dataset_, root_node=self.features_[self.head_])
|
||||
root_node = None if self.head_ is None else self.features_[self.head_]
|
||||
est = TreeSearch(self.dataset_, root_node=root_node)
|
||||
dag = est.estimate(
|
||||
estimator_type="tan",
|
||||
class_node=self.class_name_,
|
||||
show_progress=self.show_progress,
|
||||
)
|
||||
if self.head_ is None:
|
||||
self.head_ = est.root_node
|
||||
self.model_ = BayesianNetwork(dag.edges())
|
||||
self.model_.fit(
|
||||
self.dataset_,
|
||||
# estimator=MaximumLikelihoodEstimator,
|
||||
estimator=BayesianEstimator,
|
||||
prior_type="K2",
|
||||
)
|
||||
|
@@ -16,12 +16,26 @@ def data():
|
||||
return enc.fit_transform(X), y
|
||||
|
||||
|
||||
def test_TAN_classifier(data):
|
||||
def test_TAN_constructor():
|
||||
clf = TAN()
|
||||
|
||||
# Test default values of hyperparameters
|
||||
assert not clf.simple_init
|
||||
assert not clf.show_progress
|
||||
assert clf.random_state is None
|
||||
clf = TAN(simple_init=True, show_progress=True, random_state=17)
|
||||
assert clf.simple_init
|
||||
assert clf.show_progress
|
||||
assert clf.random_state == 17
|
||||
|
||||
|
||||
def test_TAN_random_head(data):
|
||||
clf = TAN(random_state=17)
|
||||
clf.fit(*data, head="random")
|
||||
assert clf.head_ == 3
|
||||
|
||||
|
||||
def test_TAN_classifier(data):
|
||||
clf = TAN()
|
||||
|
||||
clf.fit(*data)
|
||||
attribs = ["classes_", "X_", "y_", "head_", "features_", "class_name_"]
|
||||
|
432
test.ipynb
432
test.ipynb
@@ -189,39 +189,39 @@
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>31.0</td>\n",
|
||||
" <td>8.0</td>\n",
|
||||
" <td>15.0</td>\n",
|
||||
" <td>13.0</td>\n",
|
||||
" <td>30.0</td>\n",
|
||||
" <td>14.0</td>\n",
|
||||
" <td>16.0</td>\n",
|
||||
" <td>18.0</td>\n",
|
||||
" <td>38.0</td>\n",
|
||||
" <td>26.0</td>\n",
|
||||
" <td>9.0</td>\n",
|
||||
" <td>32.0</td>\n",
|
||||
" <td>14.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>23.0</td>\n",
|
||||
" <td>17.0</td>\n",
|
||||
" <td>3.0</td>\n",
|
||||
" <td>15.0</td>\n",
|
||||
" <td>19.0</td>\n",
|
||||
" <td>36.0</td>\n",
|
||||
" <td>19.0</td>\n",
|
||||
" <td>9.0</td>\n",
|
||||
" <td>18.0</td>\n",
|
||||
" <td>21.0</td>\n",
|
||||
" <td>34.0</td>\n",
|
||||
" <td>24.0</td>\n",
|
||||
" <td>10.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>31.0</td>\n",
|
||||
" <td>17.0</td>\n",
|
||||
" <td>15.0</td>\n",
|
||||
" <td>20.0</td>\n",
|
||||
" <td>30.0</td>\n",
|
||||
" <td>24.0</td>\n",
|
||||
" <td>21.0</td>\n",
|
||||
" <td>7.0</td>\n",
|
||||
" <td>15.0</td>\n",
|
||||
" <td>22.0</td>\n",
|
||||
" <td>22.0</td>\n",
|
||||
" <td>27.0</td>\n",
|
||||
" <td>6.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0</td>\n",
|
||||
@@ -229,9 +229,9 @@
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>3.0</td>\n",
|
||||
" <td>42.0</td>\n",
|
||||
" <td>51.0</td>\n",
|
||||
" <td>6.0</td>\n",
|
||||
" <td>21.0</td>\n",
|
||||
" <td>23.0</td>\n",
|
||||
" <td>47.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>3.0</td>\n",
|
||||
@@ -241,15 +241,15 @@
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>63.0</td>\n",
|
||||
" <td>62.0</td>\n",
|
||||
" <td>4.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>11.0</td>\n",
|
||||
" <td>13.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>8.0</td>\n",
|
||||
" <td>21.0</td>\n",
|
||||
" <td>30.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>4.0</td>\n",
|
||||
" <td>5.0</td>\n",
|
||||
" <td>3</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
@@ -267,12 +267,12 @@
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>209</th>\n",
|
||||
" <td>17.0</td>\n",
|
||||
" <td>22.0</td>\n",
|
||||
" <td>14.0</td>\n",
|
||||
" <td>15.0</td>\n",
|
||||
" <td>26.0</td>\n",
|
||||
" <td>21.0</td>\n",
|
||||
" <td>13.0</td>\n",
|
||||
" <td>33.0</td>\n",
|
||||
" <td>11.0</td>\n",
|
||||
" <td>19.0</td>\n",
|
||||
" <td>23.0</td>\n",
|
||||
" <td>27.0</td>\n",
|
||||
" <td>4.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
@@ -280,12 +280,12 @@
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>210</th>\n",
|
||||
" <td>14.0</td>\n",
|
||||
" <td>10.0</td>\n",
|
||||
" <td>15.0</td>\n",
|
||||
" <td>27.0</td>\n",
|
||||
" <td>25.0</td>\n",
|
||||
" <td>30.0</td>\n",
|
||||
" <td>11.0</td>\n",
|
||||
" <td>19.0</td>\n",
|
||||
" <td>18.0</td>\n",
|
||||
" <td>29.0</td>\n",
|
||||
" <td>23.0</td>\n",
|
||||
" <td>33.0</td>\n",
|
||||
" <td>3.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
@@ -293,39 +293,39 @@
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>211</th>\n",
|
||||
" <td>19.0</td>\n",
|
||||
" <td>33.0</td>\n",
|
||||
" <td>15.0</td>\n",
|
||||
" <td>17.0</td>\n",
|
||||
" <td>36.0</td>\n",
|
||||
" <td>12.0</td>\n",
|
||||
" <td>14.0</td>\n",
|
||||
" <td>41.0</td>\n",
|
||||
" <td>18.0</td>\n",
|
||||
" <td>20.0</td>\n",
|
||||
" <td>34.0</td>\n",
|
||||
" <td>14.0</td>\n",
|
||||
" <td>3.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>4.0</td>\n",
|
||||
" <td>5.0</td>\n",
|
||||
" <td>3</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>212</th>\n",
|
||||
" <td>23.0</td>\n",
|
||||
" <td>5.0</td>\n",
|
||||
" <td>20.0</td>\n",
|
||||
" <td>8.0</td>\n",
|
||||
" <td>21.0</td>\n",
|
||||
" <td>43.0</td>\n",
|
||||
" <td>30.0</td>\n",
|
||||
" <td>9.0</td>\n",
|
||||
" <td>8.0</td>\n",
|
||||
" <td>23.0</td>\n",
|
||||
" <td>42.0</td>\n",
|
||||
" <td>33.0</td>\n",
|
||||
" <td>11.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>3</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>213</th>\n",
|
||||
" <td>44.0</td>\n",
|
||||
" <td>38.0</td>\n",
|
||||
" <td>43.0</td>\n",
|
||||
" <td>46.0</td>\n",
|
||||
" <td>6.0</td>\n",
|
||||
" <td>21.0</td>\n",
|
||||
" <td>25.0</td>\n",
|
||||
" <td>23.0</td>\n",
|
||||
" <td>23.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>10.0</td>\n",
|
||||
" <td>15.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>0.0</td>\n",
|
||||
" <td>2</td>\n",
|
||||
@@ -337,17 +337,17 @@
|
||||
],
|
||||
"text/plain": [
|
||||
" RI Na Mg Al Si 'K' Ca Ba Fe Type\n",
|
||||
"0 31.0 8.0 15.0 13.0 38.0 26.0 9.0 0.0 0.0 0\n",
|
||||
"1 23.0 3.0 15.0 19.0 36.0 19.0 9.0 0.0 0.0 1\n",
|
||||
"2 31.0 17.0 15.0 20.0 24.0 21.0 7.0 0.0 0.0 0\n",
|
||||
"3 3.0 42.0 6.0 21.0 47.0 0.0 3.0 0.0 0.0 2\n",
|
||||
"4 63.0 4.0 0.0 11.0 0.0 8.0 21.0 0.0 4.0 3\n",
|
||||
"0 30.0 14.0 16.0 18.0 38.0 32.0 14.0 0.0 0.0 0\n",
|
||||
"1 17.0 3.0 18.0 21.0 34.0 24.0 10.0 0.0 0.0 1\n",
|
||||
"2 30.0 24.0 15.0 22.0 22.0 27.0 6.0 0.0 0.0 0\n",
|
||||
"3 3.0 51.0 6.0 23.0 47.0 0.0 3.0 0.0 0.0 2\n",
|
||||
"4 62.0 4.0 0.0 13.0 0.0 8.0 30.0 0.0 5.0 3\n",
|
||||
".. ... ... ... ... ... ... ... ... ... ...\n",
|
||||
"209 17.0 22.0 14.0 15.0 26.0 21.0 4.0 0.0 0.0 1\n",
|
||||
"210 14.0 10.0 15.0 27.0 25.0 30.0 3.0 0.0 0.0 3\n",
|
||||
"211 19.0 33.0 15.0 17.0 36.0 12.0 3.0 0.0 4.0 3\n",
|
||||
"212 23.0 5.0 8.0 21.0 43.0 30.0 9.0 0.0 0.0 3\n",
|
||||
"213 44.0 38.0 6.0 21.0 25.0 0.0 10.0 0.0 0.0 2\n",
|
||||
"209 13.0 33.0 11.0 19.0 23.0 27.0 4.0 0.0 0.0 1\n",
|
||||
"210 11.0 19.0 18.0 29.0 23.0 33.0 3.0 0.0 0.0 3\n",
|
||||
"211 14.0 41.0 18.0 20.0 34.0 14.0 3.0 0.0 5.0 3\n",
|
||||
"212 20.0 8.0 8.0 23.0 42.0 33.0 11.0 0.0 0.0 3\n",
|
||||
"213 43.0 46.0 6.0 23.0 23.0 0.0 15.0 0.0 0.0 2\n",
|
||||
"\n",
|
||||
"[214 rows x 10 columns]"
|
||||
]
|
||||
@@ -373,39 +373,317 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6a1aad95-370f-4854-ae9a-32205aff5d39",
|
||||
"execution_count": 17,
|
||||
"id": "2840a103-99fb-466f-ae75-45e11c1b9c5a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for simple_init in [False, True]:\n",
|
||||
" model = TAN(simple_init=simple_init)\n",
|
||||
" for head in range(4):\n",
|
||||
" model.fit(X, y, head=head, features=features, class_name=class_name)\n",
|
||||
" ypred = model.predict(X)\n",
|
||||
" #model.plot(f\"simple_init={simple_init} head={head} score={model.predict(X)}\")"
|
||||
"from sklearn.model_selection import cross_validate, StratifiedKFold, KFold, cross_val_score\n",
|
||||
"import numpy as np\n",
|
||||
"n_folds = 5\n",
|
||||
"score_name = \"accuracy\"\n",
|
||||
"random_state=17\n",
|
||||
"def validate_classifier(model, X, y, stratified, fit_params):\n",
|
||||
" stratified_class = StratifiedKFold if stratified else KFold\n",
|
||||
" kfold = stratified_class(shuffle=True, random_state=random_state, n_splits=n_folds)\n",
|
||||
" #return cross_validate(model, X, y, cv=kfold, return_estimator=True, scoring=score_name)\n",
|
||||
" return cross_val_score(model, X, y, fit_params=fit_params)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "76905bf3",
|
||||
"execution_count": 20,
|
||||
"id": "6a1aad95-370f-4854-ae9a-32205aff5d39",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "7bdb666c5e5140e688141356958b362f",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/43 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "0110029f56a4451cb877eeaae42e8e3f",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/43 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "1e8130ed7d3a4f4da499dc481df369ab",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/43 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "52adf5efe6874dcfa1e776c6a3c47f2c",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/43 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "27fc7a5b95434c7c86ab2ddde212e006",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
" 0%| | 0/42 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n",
|
||||
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"ename": "AttributeError",
|
||||
"evalue": "'TAN' object has no attribute 'model_'",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn [20], line 10\u001b[0m\n\u001b[1;32m 8\u001b[0m score \u001b[38;5;241m=\u001b[39m validate_classifier(model, X, y, stratified\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, fit_params\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mdict\u001b[39m(head\u001b[38;5;241m=\u001b[39mhead, features\u001b[38;5;241m=\u001b[39mfeatures, class_name\u001b[38;5;241m=\u001b[39mclass_name))\n\u001b[1;32m 9\u001b[0m \u001b[38;5;66;03m#model.plot(f\"simple_init={simple_init} head={head} score={np.mean(score['test_score'])}\")\u001b[39;00m\n\u001b[0;32m---> 10\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mplot\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43msimple_init=\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43msimple_init\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m head=\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mhead\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m score=\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmean\u001b[49m\u001b[43m(\u001b[49m\u001b[43mscore\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/Code/bayesclass/bayesclass/bayesclass.py:148\u001b[0m, in \u001b[0;36mTAN.plot\u001b[0;34m(self, title)\u001b[0m\n\u001b[1;32m 146\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mplot\u001b[39m(\u001b[38;5;28mself\u001b[39m, title\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 147\u001b[0m nx\u001b[38;5;241m.\u001b[39mdraw_circular(\n\u001b[0;32m--> 148\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel_\u001b[49m,\n\u001b[1;32m 149\u001b[0m with_labels\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 150\u001b[0m arrowsize\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m30\u001b[39m,\n\u001b[1;32m 151\u001b[0m node_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m800\u001b[39m,\n\u001b[1;32m 152\u001b[0m alpha\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.3\u001b[39m,\n\u001b[1;32m 153\u001b[0m font_weight\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbold\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 154\u001b[0m )\n\u001b[1;32m 155\u001b[0m plt\u001b[38;5;241m.\u001b[39mtitle(title)\n\u001b[1;32m 156\u001b[0m plt\u001b[38;5;241m.\u001b[39mshow()\n",
|
||||
"\u001b[0;31mAttributeError\u001b[0m: 'TAN' object has no attribute 'model_'"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import warnings\n",
|
||||
"from stree import Stree\n",
|
||||
"warnings.filterwarnings('ignore')\n",
|
||||
"for simple_init in [False, True]:\n",
|
||||
" model = TAN(simple_init=simple_init)\n",
|
||||
" for head in range(4):\n",
|
||||
" #model.fit(X, y, head=head, features=features, class_name=class_name)\n",
|
||||
" score = validate_classifier(model, X, y, stratified=False, fit_params=dict(head=head, features=features, class_name=class_name))\n",
|
||||
" #model.plot(f\"simple_init={simple_init} head={head} score={np.mean(score['test_score'])}\")\n",
|
||||
" model.plot(f\"simple_init={simple_init} head={head} score={np.mean(score)}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"id": "c389ff1e-76d9-4c5b-9860-ea6d4752fac7",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(214, 9)"
|
||||
"array([nan, nan, nan, nan, nan])"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"X.shape\n"
|
||||
"score"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"id": "9c58629f-000b-4d8c-8896-efd032f1090c",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"b 10\n",
|
||||
"c 9\n",
|
||||
"d 8\n",
|
||||
"e 7\n",
|
||||
"a 6\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from queue import PriorityQueue\n",
|
||||
"q = PriorityQueue()\n",
|
||||
"lista = ['b', 'c', 'd', 'e', 'a']\n",
|
||||
"for i, c in zip(lista, range(len(lista))):\n",
|
||||
" print(i,10-c)\n",
|
||||
" q.put(i,10-c)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"id": "e2a768c0-3e21-48f3-b118-25408122d01c",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"a\n",
|
||||
"b\n",
|
||||
"c\n",
|
||||
"d\n",
|
||||
"e\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"while not q.empty():\n",
|
||||
" print(q.get())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "96bb1acd-f450-4b9c-8f54-f020e23dfc14",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
Reference in New Issue
Block a user