mirror of
https://github.com/Doctorado-ML/bayesclass.git
synced 2025-08-17 16:45:54 +00:00
Update head hyperparam to use highest weight
This commit is contained in:
File diff suppressed because one or more lines are too long
@@ -1,12 +1,18 @@
|
|||||||
"""
|
"""
|
||||||
This is a module to be used as a reference for building other modules
|
This is a module to be used as a reference for building other modules
|
||||||
"""
|
"""
|
||||||
|
import random
|
||||||
|
from itertools import combinations
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from sklearn.base import ClassifierMixin, BaseEstimator
|
from sklearn.base import ClassifierMixin, BaseEstimator
|
||||||
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
|
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
|
||||||
from sklearn.utils.multiclass import unique_labels
|
from sklearn.utils.multiclass import unique_labels
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from pgmpy.estimators import TreeSearch, BayesianEstimator
|
from pgmpy.estimators import (
|
||||||
|
TreeSearch,
|
||||||
|
BayesianEstimator,
|
||||||
|
MaximumLikelihoodEstimator,
|
||||||
|
)
|
||||||
from pgmpy.models import BayesianNetwork
|
from pgmpy.models import BayesianNetwork
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
@@ -29,9 +35,12 @@ class TAN(ClassifierMixin, BaseEstimator):
|
|||||||
The classes seen at :meth:`fit`.
|
The classes seen at :meth:`fit`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, simple_init=False, show_progress=False):
|
def __init__(
|
||||||
|
self, simple_init=False, show_progress=False, random_state=None
|
||||||
|
):
|
||||||
self.simple_init = simple_init
|
self.simple_init = simple_init
|
||||||
self.show_progress = show_progress
|
self.show_progress = show_progress
|
||||||
|
self.random_state = random_state
|
||||||
|
|
||||||
def fit(self, X, y, **kwargs):
|
def fit(self, X, y, **kwargs):
|
||||||
"""A reference implementation of a fitting function for a classifier.
|
"""A reference implementation of a fitting function for a classifier.
|
||||||
@@ -44,7 +53,8 @@ class TAN(ClassifierMixin, BaseEstimator):
|
|||||||
**kwargs : dict
|
**kwargs : dict
|
||||||
class_name : str (default='class') Name of the class column
|
class_name : str (default='class') Name of the class column
|
||||||
features: list (default=None) List of features
|
features: list (default=None) List of features
|
||||||
head: int (default=0) Index of the head node
|
head: int (default=None) Index of the head node. Default value
|
||||||
|
gets the node with the highest sum of weights (mutual_info)
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
self : object
|
self : object
|
||||||
@@ -57,20 +67,22 @@ class TAN(ClassifierMixin, BaseEstimator):
|
|||||||
# Default values
|
# Default values
|
||||||
self.class_name_ = "class"
|
self.class_name_ = "class"
|
||||||
self.features_ = [f"feature_{i}" for i in range(X.shape[1])]
|
self.features_ = [f"feature_{i}" for i in range(X.shape[1])]
|
||||||
self.head_ = 0
|
self.head_ = None
|
||||||
expected_args = ["class_name", "features", "head"]
|
expected_args = ["class_name", "features", "head"]
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
if key in expected_args:
|
if key in expected_args:
|
||||||
setattr(self, f"{key}_", value)
|
setattr(self, f"{key}_", value)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unexpected argument: {key}")
|
raise ValueError(f"Unexpected argument: {key}")
|
||||||
|
if self.random_state is not None:
|
||||||
|
random.seed(self.random_state)
|
||||||
|
if self.head_ == "random":
|
||||||
|
self.head_ = random.randint(0, len(self.features_) - 1)
|
||||||
if len(self.features_) != X.shape[1]:
|
if len(self.features_) != X.shape[1]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Number of features does not match the number of columns in X"
|
"Number of features does not match the number of columns in X"
|
||||||
)
|
)
|
||||||
if self.head_ >= len(self.features_):
|
if self.head_ is not None and self.head_ >= len(self.features_):
|
||||||
raise ValueError("Head index out of range")
|
raise ValueError("Head index out of range")
|
||||||
|
|
||||||
self.X_ = X
|
self.X_ = X
|
||||||
@@ -80,37 +92,57 @@ class TAN(ClassifierMixin, BaseEstimator):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def __initial_edges(self):
|
def __initial_edges(self):
|
||||||
|
"""As with the naive Bayes, in a TAN structure, the class has no
|
||||||
|
parents, while features must have the class as parent and are forced to
|
||||||
|
have one other feature as parent too (except for one single feature,
|
||||||
|
which has only the class as parent and is considered the root of the
|
||||||
|
features' tree)
|
||||||
|
Cassio P. de Campos, Giorgio Corani, Mauro Scanagatta, Marco Cuccu,
|
||||||
|
Marco Zaffalon,
|
||||||
|
Learning extended tree augmented naive structures,
|
||||||
|
International Journal of Approximate Reasoning,
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
List
|
||||||
|
List of edges
|
||||||
|
"""
|
||||||
|
head = 0 if self.head_ is None else self.head_
|
||||||
if self.simple_init:
|
if self.simple_init:
|
||||||
first_node = self.features_[self.head_]
|
first_node = self.features_[head]
|
||||||
return [
|
return [
|
||||||
(first_node, feature)
|
(first_node, feature)
|
||||||
for feature in self.features_
|
for feature in self.features_
|
||||||
if feature != first_node
|
if feature != first_node
|
||||||
]
|
]
|
||||||
edges = []
|
# initialize a complete network with all edges starting from head
|
||||||
for i in range(len(self.features_)):
|
reordered = [
|
||||||
for j in range(i + 1, len(self.features_)):
|
self.features_[idx % len(self.features_)]
|
||||||
edges.append((self.features_[i], self.features_[j]))
|
for idx in range(head, len(self.features_) + head)
|
||||||
return edges
|
]
|
||||||
|
return list(combinations(reordered, 2))
|
||||||
|
|
||||||
def __train(self):
|
def __train(self):
|
||||||
|
# Initialize a Naive Bayes model
|
||||||
net = [(self.class_name_, feature) for feature in self.features_]
|
net = [(self.class_name_, feature) for feature in self.features_]
|
||||||
self.model_ = BayesianNetwork(net)
|
self.model_ = BayesianNetwork(net)
|
||||||
# initialize a complete network with all edges
|
# initialize a complete network with all edges
|
||||||
self.model_.add_edges_from(self.__initial_edges())
|
self.model_.add_edges_from(self.__initial_edges())
|
||||||
|
|
||||||
self.dataset_ = pd.DataFrame(self.X_, columns=self.features_)
|
self.dataset_ = pd.DataFrame(self.X_, columns=self.features_)
|
||||||
self.dataset_[self.class_name_] = self.y_
|
self.dataset_[self.class_name_] = self.y_
|
||||||
# learn graph structure
|
# learn graph structure
|
||||||
est = TreeSearch(self.dataset_, root_node=self.features_[self.head_])
|
root_node = None if self.head_ is None else self.features_[self.head_]
|
||||||
|
est = TreeSearch(self.dataset_, root_node=root_node)
|
||||||
dag = est.estimate(
|
dag = est.estimate(
|
||||||
estimator_type="tan",
|
estimator_type="tan",
|
||||||
class_node=self.class_name_,
|
class_node=self.class_name_,
|
||||||
show_progress=self.show_progress,
|
show_progress=self.show_progress,
|
||||||
)
|
)
|
||||||
|
if self.head_ is None:
|
||||||
|
self.head_ = est.root_node
|
||||||
self.model_ = BayesianNetwork(dag.edges())
|
self.model_ = BayesianNetwork(dag.edges())
|
||||||
self.model_.fit(
|
self.model_.fit(
|
||||||
self.dataset_,
|
self.dataset_,
|
||||||
|
# estimator=MaximumLikelihoodEstimator,
|
||||||
estimator=BayesianEstimator,
|
estimator=BayesianEstimator,
|
||||||
prior_type="K2",
|
prior_type="K2",
|
||||||
)
|
)
|
||||||
|
@@ -16,12 +16,26 @@ def data():
|
|||||||
return enc.fit_transform(X), y
|
return enc.fit_transform(X), y
|
||||||
|
|
||||||
|
|
||||||
def test_TAN_classifier(data):
|
def test_TAN_constructor():
|
||||||
clf = TAN()
|
clf = TAN()
|
||||||
|
|
||||||
# Test default values of hyperparameters
|
# Test default values of hyperparameters
|
||||||
assert not clf.simple_init
|
assert not clf.simple_init
|
||||||
assert not clf.show_progress
|
assert not clf.show_progress
|
||||||
|
assert clf.random_state is None
|
||||||
|
clf = TAN(simple_init=True, show_progress=True, random_state=17)
|
||||||
|
assert clf.simple_init
|
||||||
|
assert clf.show_progress
|
||||||
|
assert clf.random_state == 17
|
||||||
|
|
||||||
|
|
||||||
|
def test_TAN_random_head(data):
|
||||||
|
clf = TAN(random_state=17)
|
||||||
|
clf.fit(*data, head="random")
|
||||||
|
assert clf.head_ == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_TAN_classifier(data):
|
||||||
|
clf = TAN()
|
||||||
|
|
||||||
clf.fit(*data)
|
clf.fit(*data)
|
||||||
attribs = ["classes_", "X_", "y_", "head_", "features_", "class_name_"]
|
attribs = ["classes_", "X_", "y_", "head_", "features_", "class_name_"]
|
||||||
|
432
test.ipynb
432
test.ipynb
@@ -189,39 +189,39 @@
|
|||||||
" <tbody>\n",
|
" <tbody>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
" <th>0</th>\n",
|
" <th>0</th>\n",
|
||||||
" <td>31.0</td>\n",
|
" <td>30.0</td>\n",
|
||||||
" <td>8.0</td>\n",
|
" <td>14.0</td>\n",
|
||||||
" <td>15.0</td>\n",
|
" <td>16.0</td>\n",
|
||||||
" <td>13.0</td>\n",
|
" <td>18.0</td>\n",
|
||||||
" <td>38.0</td>\n",
|
" <td>38.0</td>\n",
|
||||||
" <td>26.0</td>\n",
|
" <td>32.0</td>\n",
|
||||||
" <td>9.0</td>\n",
|
" <td>14.0</td>\n",
|
||||||
" <td>0.0</td>\n",
|
" <td>0.0</td>\n",
|
||||||
" <td>0.0</td>\n",
|
" <td>0.0</td>\n",
|
||||||
" <td>0</td>\n",
|
" <td>0</td>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
" <th>1</th>\n",
|
" <th>1</th>\n",
|
||||||
" <td>23.0</td>\n",
|
" <td>17.0</td>\n",
|
||||||
" <td>3.0</td>\n",
|
" <td>3.0</td>\n",
|
||||||
" <td>15.0</td>\n",
|
" <td>18.0</td>\n",
|
||||||
" <td>19.0</td>\n",
|
" <td>21.0</td>\n",
|
||||||
" <td>36.0</td>\n",
|
" <td>34.0</td>\n",
|
||||||
" <td>19.0</td>\n",
|
" <td>24.0</td>\n",
|
||||||
" <td>9.0</td>\n",
|
" <td>10.0</td>\n",
|
||||||
" <td>0.0</td>\n",
|
" <td>0.0</td>\n",
|
||||||
" <td>0.0</td>\n",
|
" <td>0.0</td>\n",
|
||||||
" <td>1</td>\n",
|
" <td>1</td>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
" <th>2</th>\n",
|
" <th>2</th>\n",
|
||||||
" <td>31.0</td>\n",
|
" <td>30.0</td>\n",
|
||||||
" <td>17.0</td>\n",
|
|
||||||
" <td>15.0</td>\n",
|
|
||||||
" <td>20.0</td>\n",
|
|
||||||
" <td>24.0</td>\n",
|
" <td>24.0</td>\n",
|
||||||
" <td>21.0</td>\n",
|
" <td>15.0</td>\n",
|
||||||
" <td>7.0</td>\n",
|
" <td>22.0</td>\n",
|
||||||
|
" <td>22.0</td>\n",
|
||||||
|
" <td>27.0</td>\n",
|
||||||
|
" <td>6.0</td>\n",
|
||||||
" <td>0.0</td>\n",
|
" <td>0.0</td>\n",
|
||||||
" <td>0.0</td>\n",
|
" <td>0.0</td>\n",
|
||||||
" <td>0</td>\n",
|
" <td>0</td>\n",
|
||||||
@@ -229,9 +229,9 @@
|
|||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
" <th>3</th>\n",
|
" <th>3</th>\n",
|
||||||
" <td>3.0</td>\n",
|
" <td>3.0</td>\n",
|
||||||
" <td>42.0</td>\n",
|
" <td>51.0</td>\n",
|
||||||
" <td>6.0</td>\n",
|
" <td>6.0</td>\n",
|
||||||
" <td>21.0</td>\n",
|
" <td>23.0</td>\n",
|
||||||
" <td>47.0</td>\n",
|
" <td>47.0</td>\n",
|
||||||
" <td>0.0</td>\n",
|
" <td>0.0</td>\n",
|
||||||
" <td>3.0</td>\n",
|
" <td>3.0</td>\n",
|
||||||
@@ -241,15 +241,15 @@
|
|||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
" <th>4</th>\n",
|
" <th>4</th>\n",
|
||||||
" <td>63.0</td>\n",
|
" <td>62.0</td>\n",
|
||||||
" <td>4.0</td>\n",
|
" <td>4.0</td>\n",
|
||||||
" <td>0.0</td>\n",
|
" <td>0.0</td>\n",
|
||||||
" <td>11.0</td>\n",
|
" <td>13.0</td>\n",
|
||||||
" <td>0.0</td>\n",
|
" <td>0.0</td>\n",
|
||||||
" <td>8.0</td>\n",
|
" <td>8.0</td>\n",
|
||||||
" <td>21.0</td>\n",
|
" <td>30.0</td>\n",
|
||||||
" <td>0.0</td>\n",
|
" <td>0.0</td>\n",
|
||||||
" <td>4.0</td>\n",
|
" <td>5.0</td>\n",
|
||||||
" <td>3</td>\n",
|
" <td>3</td>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
@@ -267,12 +267,12 @@
|
|||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
" <th>209</th>\n",
|
" <th>209</th>\n",
|
||||||
" <td>17.0</td>\n",
|
" <td>13.0</td>\n",
|
||||||
" <td>22.0</td>\n",
|
" <td>33.0</td>\n",
|
||||||
" <td>14.0</td>\n",
|
" <td>11.0</td>\n",
|
||||||
" <td>15.0</td>\n",
|
" <td>19.0</td>\n",
|
||||||
" <td>26.0</td>\n",
|
" <td>23.0</td>\n",
|
||||||
" <td>21.0</td>\n",
|
" <td>27.0</td>\n",
|
||||||
" <td>4.0</td>\n",
|
" <td>4.0</td>\n",
|
||||||
" <td>0.0</td>\n",
|
" <td>0.0</td>\n",
|
||||||
" <td>0.0</td>\n",
|
" <td>0.0</td>\n",
|
||||||
@@ -280,12 +280,12 @@
|
|||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
" <th>210</th>\n",
|
" <th>210</th>\n",
|
||||||
" <td>14.0</td>\n",
|
" <td>11.0</td>\n",
|
||||||
" <td>10.0</td>\n",
|
" <td>19.0</td>\n",
|
||||||
" <td>15.0</td>\n",
|
" <td>18.0</td>\n",
|
||||||
" <td>27.0</td>\n",
|
" <td>29.0</td>\n",
|
||||||
" <td>25.0</td>\n",
|
" <td>23.0</td>\n",
|
||||||
" <td>30.0</td>\n",
|
" <td>33.0</td>\n",
|
||||||
" <td>3.0</td>\n",
|
" <td>3.0</td>\n",
|
||||||
" <td>0.0</td>\n",
|
" <td>0.0</td>\n",
|
||||||
" <td>0.0</td>\n",
|
" <td>0.0</td>\n",
|
||||||
@@ -293,39 +293,39 @@
|
|||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
" <th>211</th>\n",
|
" <th>211</th>\n",
|
||||||
" <td>19.0</td>\n",
|
" <td>14.0</td>\n",
|
||||||
" <td>33.0</td>\n",
|
" <td>41.0</td>\n",
|
||||||
" <td>15.0</td>\n",
|
" <td>18.0</td>\n",
|
||||||
" <td>17.0</td>\n",
|
" <td>20.0</td>\n",
|
||||||
" <td>36.0</td>\n",
|
" <td>34.0</td>\n",
|
||||||
" <td>12.0</td>\n",
|
" <td>14.0</td>\n",
|
||||||
" <td>3.0</td>\n",
|
" <td>3.0</td>\n",
|
||||||
" <td>0.0</td>\n",
|
" <td>0.0</td>\n",
|
||||||
" <td>4.0</td>\n",
|
" <td>5.0</td>\n",
|
||||||
" <td>3</td>\n",
|
" <td>3</td>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
" <th>212</th>\n",
|
" <th>212</th>\n",
|
||||||
" <td>23.0</td>\n",
|
" <td>20.0</td>\n",
|
||||||
" <td>5.0</td>\n",
|
|
||||||
" <td>8.0</td>\n",
|
" <td>8.0</td>\n",
|
||||||
" <td>21.0</td>\n",
|
" <td>8.0</td>\n",
|
||||||
" <td>43.0</td>\n",
|
" <td>23.0</td>\n",
|
||||||
" <td>30.0</td>\n",
|
" <td>42.0</td>\n",
|
||||||
" <td>9.0</td>\n",
|
" <td>33.0</td>\n",
|
||||||
|
" <td>11.0</td>\n",
|
||||||
" <td>0.0</td>\n",
|
" <td>0.0</td>\n",
|
||||||
" <td>0.0</td>\n",
|
" <td>0.0</td>\n",
|
||||||
" <td>3</td>\n",
|
" <td>3</td>\n",
|
||||||
" </tr>\n",
|
" </tr>\n",
|
||||||
" <tr>\n",
|
" <tr>\n",
|
||||||
" <th>213</th>\n",
|
" <th>213</th>\n",
|
||||||
" <td>44.0</td>\n",
|
" <td>43.0</td>\n",
|
||||||
" <td>38.0</td>\n",
|
" <td>46.0</td>\n",
|
||||||
" <td>6.0</td>\n",
|
" <td>6.0</td>\n",
|
||||||
" <td>21.0</td>\n",
|
" <td>23.0</td>\n",
|
||||||
" <td>25.0</td>\n",
|
" <td>23.0</td>\n",
|
||||||
" <td>0.0</td>\n",
|
" <td>0.0</td>\n",
|
||||||
" <td>10.0</td>\n",
|
" <td>15.0</td>\n",
|
||||||
" <td>0.0</td>\n",
|
" <td>0.0</td>\n",
|
||||||
" <td>0.0</td>\n",
|
" <td>0.0</td>\n",
|
||||||
" <td>2</td>\n",
|
" <td>2</td>\n",
|
||||||
@@ -337,17 +337,17 @@
|
|||||||
],
|
],
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
" RI Na Mg Al Si 'K' Ca Ba Fe Type\n",
|
" RI Na Mg Al Si 'K' Ca Ba Fe Type\n",
|
||||||
"0 31.0 8.0 15.0 13.0 38.0 26.0 9.0 0.0 0.0 0\n",
|
"0 30.0 14.0 16.0 18.0 38.0 32.0 14.0 0.0 0.0 0\n",
|
||||||
"1 23.0 3.0 15.0 19.0 36.0 19.0 9.0 0.0 0.0 1\n",
|
"1 17.0 3.0 18.0 21.0 34.0 24.0 10.0 0.0 0.0 1\n",
|
||||||
"2 31.0 17.0 15.0 20.0 24.0 21.0 7.0 0.0 0.0 0\n",
|
"2 30.0 24.0 15.0 22.0 22.0 27.0 6.0 0.0 0.0 0\n",
|
||||||
"3 3.0 42.0 6.0 21.0 47.0 0.0 3.0 0.0 0.0 2\n",
|
"3 3.0 51.0 6.0 23.0 47.0 0.0 3.0 0.0 0.0 2\n",
|
||||||
"4 63.0 4.0 0.0 11.0 0.0 8.0 21.0 0.0 4.0 3\n",
|
"4 62.0 4.0 0.0 13.0 0.0 8.0 30.0 0.0 5.0 3\n",
|
||||||
".. ... ... ... ... ... ... ... ... ... ...\n",
|
".. ... ... ... ... ... ... ... ... ... ...\n",
|
||||||
"209 17.0 22.0 14.0 15.0 26.0 21.0 4.0 0.0 0.0 1\n",
|
"209 13.0 33.0 11.0 19.0 23.0 27.0 4.0 0.0 0.0 1\n",
|
||||||
"210 14.0 10.0 15.0 27.0 25.0 30.0 3.0 0.0 0.0 3\n",
|
"210 11.0 19.0 18.0 29.0 23.0 33.0 3.0 0.0 0.0 3\n",
|
||||||
"211 19.0 33.0 15.0 17.0 36.0 12.0 3.0 0.0 4.0 3\n",
|
"211 14.0 41.0 18.0 20.0 34.0 14.0 3.0 0.0 5.0 3\n",
|
||||||
"212 23.0 5.0 8.0 21.0 43.0 30.0 9.0 0.0 0.0 3\n",
|
"212 20.0 8.0 8.0 23.0 42.0 33.0 11.0 0.0 0.0 3\n",
|
||||||
"213 44.0 38.0 6.0 21.0 25.0 0.0 10.0 0.0 0.0 2\n",
|
"213 43.0 46.0 6.0 23.0 23.0 0.0 15.0 0.0 0.0 2\n",
|
||||||
"\n",
|
"\n",
|
||||||
"[214 rows x 10 columns]"
|
"[214 rows x 10 columns]"
|
||||||
]
|
]
|
||||||
@@ -373,39 +373,317 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 17,
|
||||||
"id": "6a1aad95-370f-4854-ae9a-32205aff5d39",
|
"id": "2840a103-99fb-466f-ae75-45e11c1b9c5a",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"for simple_init in [False, True]:\n",
|
"from sklearn.model_selection import cross_validate, StratifiedKFold, KFold, cross_val_score\n",
|
||||||
" model = TAN(simple_init=simple_init)\n",
|
"import numpy as np\n",
|
||||||
" for head in range(4):\n",
|
"n_folds = 5\n",
|
||||||
" model.fit(X, y, head=head, features=features, class_name=class_name)\n",
|
"score_name = \"accuracy\"\n",
|
||||||
" ypred = model.predict(X)\n",
|
"random_state=17\n",
|
||||||
" #model.plot(f\"simple_init={simple_init} head={head} score={model.predict(X)}\")"
|
"def validate_classifier(model, X, y, stratified, fit_params):\n",
|
||||||
|
" stratified_class = StratifiedKFold if stratified else KFold\n",
|
||||||
|
" kfold = stratified_class(shuffle=True, random_state=random_state, n_splits=n_folds)\n",
|
||||||
|
" #return cross_validate(model, X, y, cv=kfold, return_estimator=True, scoring=score_name)\n",
|
||||||
|
" return cross_val_score(model, X, y, fit_params=fit_params)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": 20,
|
||||||
"id": "76905bf3",
|
"id": "6a1aad95-370f-4854-ae9a-32205aff5d39",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "7bdb666c5e5140e688141356958b362f",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
" 0%| | 0/43 [00:00<?, ?it/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "0110029f56a4451cb877eeaae42e8e3f",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
" 0%| | 0/43 [00:00<?, ?it/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "1e8130ed7d3a4f4da499dc481df369ab",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
" 0%| | 0/43 [00:00<?, ?it/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "52adf5efe6874dcfa1e776c6a3c47f2c",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
" 0%| | 0/43 [00:00<?, ?it/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "27fc7a5b95434c7c86ab2ddde212e006",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
" 0%| | 0/42 [00:00<?, ?it/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n",
|
||||||
|
"/Users/rmontanana/Code/pgmpy/pgmpy/factors/discrete/DiscreteFactor.py:541: UserWarning: Found unknown state name. Trying to switch to using all state names as state numbers\n",
|
||||||
|
" warnings.warn(\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"ename": "AttributeError",
|
||||||
|
"evalue": "'TAN' object has no attribute 'model_'",
|
||||||
|
"output_type": "error",
|
||||||
|
"traceback": [
|
||||||
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||||
|
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
|
||||||
|
"Cell \u001b[0;32mIn [20], line 10\u001b[0m\n\u001b[1;32m 8\u001b[0m score \u001b[38;5;241m=\u001b[39m validate_classifier(model, X, y, stratified\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, fit_params\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mdict\u001b[39m(head\u001b[38;5;241m=\u001b[39mhead, features\u001b[38;5;241m=\u001b[39mfeatures, class_name\u001b[38;5;241m=\u001b[39mclass_name))\n\u001b[1;32m 9\u001b[0m \u001b[38;5;66;03m#model.plot(f\"simple_init={simple_init} head={head} score={np.mean(score['test_score'])}\")\u001b[39;00m\n\u001b[0;32m---> 10\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mplot\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43msimple_init=\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43msimple_init\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m head=\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mhead\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m score=\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmean\u001b[49m\u001b[43m(\u001b[49m\u001b[43mscore\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
|
||||||
|
"File \u001b[0;32m~/Code/bayesclass/bayesclass/bayesclass.py:148\u001b[0m, in \u001b[0;36mTAN.plot\u001b[0;34m(self, title)\u001b[0m\n\u001b[1;32m 146\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mplot\u001b[39m(\u001b[38;5;28mself\u001b[39m, title\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 147\u001b[0m nx\u001b[38;5;241m.\u001b[39mdraw_circular(\n\u001b[0;32m--> 148\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel_\u001b[49m,\n\u001b[1;32m 149\u001b[0m with_labels\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 150\u001b[0m arrowsize\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m30\u001b[39m,\n\u001b[1;32m 151\u001b[0m node_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m800\u001b[39m,\n\u001b[1;32m 152\u001b[0m alpha\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.3\u001b[39m,\n\u001b[1;32m 153\u001b[0m font_weight\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbold\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 154\u001b[0m )\n\u001b[1;32m 155\u001b[0m plt\u001b[38;5;241m.\u001b[39mtitle(title)\n\u001b[1;32m 156\u001b[0m plt\u001b[38;5;241m.\u001b[39mshow()\n",
|
||||||
|
"\u001b[0;31mAttributeError\u001b[0m: 'TAN' object has no attribute 'model_'"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import warnings\n",
|
||||||
|
"from stree import Stree\n",
|
||||||
|
"warnings.filterwarnings('ignore')\n",
|
||||||
|
"for simple_init in [False, True]:\n",
|
||||||
|
" model = TAN(simple_init=simple_init)\n",
|
||||||
|
" for head in range(4):\n",
|
||||||
|
" #model.fit(X, y, head=head, features=features, class_name=class_name)\n",
|
||||||
|
" score = validate_classifier(model, X, y, stratified=False, fit_params=dict(head=head, features=features, class_name=class_name))\n",
|
||||||
|
" #model.plot(f\"simple_init={simple_init} head={head} score={np.mean(score['test_score'])}\")\n",
|
||||||
|
" model.plot(f\"simple_init={simple_init} head={head} score={np.mean(score)}\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 21,
|
||||||
|
"id": "c389ff1e-76d9-4c5b-9860-ea6d4752fac7",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"(214, 9)"
|
"array([nan, nan, nan, nan, nan])"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 5,
|
"execution_count": 21,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"X.shape\n"
|
"score"
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 28,
|
||||||
|
"id": "9c58629f-000b-4d8c-8896-efd032f1090c",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"b 10\n",
|
||||||
|
"c 9\n",
|
||||||
|
"d 8\n",
|
||||||
|
"e 7\n",
|
||||||
|
"a 6\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from queue import PriorityQueue\n",
|
||||||
|
"q = PriorityQueue()\n",
|
||||||
|
"lista = ['b', 'c', 'd', 'e', 'a']\n",
|
||||||
|
"for i, c in zip(lista, range(len(lista))):\n",
|
||||||
|
" print(i,10-c)\n",
|
||||||
|
" q.put(i,10-c)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 25,
|
||||||
|
"id": "e2a768c0-3e21-48f3-b118-25408122d01c",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"a\n",
|
||||||
|
"b\n",
|
||||||
|
"c\n",
|
||||||
|
"d\n",
|
||||||
|
"e\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"while not q.empty():\n",
|
||||||
|
" print(q.get())"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "96bb1acd-f450-4b9c-8f54-f020e23dfc14",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
Reference in New Issue
Block a user