mirror of https://github.com/Doctorado-ML/bayesclass.git
synced 2025-08-17 16:45:54 +00:00
First KDB implementation
bayesclass/__init__.py
@@ -1,4 +1,4 @@
-from .bayesclass import TAN
+from .bayesclass import TAN, KDB
 from ._version import __version__

 __author__ = "Ricardo Montañana Gómez"
@@ -6,4 +6,4 @@ __copyright__ = "Copyright 2020-2023, Ricardo Montañana Gómez"
 __license__ = "MIT License"
 __author_email__ = "ricardo.montanana@alu.uclm.es"

-__all__ = ["TAN", "__version__"]
+__all__ = ["TAN", "KDB", "__version__"]
bayesclass/bayesclass.py
@@ -2,12 +2,12 @@
 This is a module to be used as a reference for building other modules
 """
 import random
-from itertools import combinations
 import numpy as np
 import pandas as pd
 from sklearn.base import ClassifierMixin, BaseEstimator
 from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
 from sklearn.utils.multiclass import unique_labels
+from sklearn.feature_selection import mutual_info_classif
 import networkx as nx
 from pgmpy.estimators import TreeSearch, BayesianEstimator
 from pgmpy.models import BayesianNetwork
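The newly imported mutual_info_classif feeds step 1 of the KDB algorithm implemented further down. As a minimal sketch, not part of the commit, assuming the same discretized iris setup the new tests use:

    import numpy as np
    from sklearn.datasets import load_iris
    from sklearn.preprocessing import KBinsDiscretizer
    from sklearn.feature_selection import mutual_info_classif

    X, y = load_iris(return_X_y=True)
    X = KBinsDiscretizer(encode="ordinal").fit_transform(X)
    # One estimate of I(Xi; C) per feature; exact values vary with the data
    mutual = mutual_info_classif(X, y, discrete_features=True)
    print(np.argsort(mutual))  # feature indices, least informative first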
@@ -16,6 +16,10 @@ from ._version import __version__


 class BayesBase(BaseEstimator, ClassifierMixin):
+    def __init__(self, random_state, show_progress):
+        self.random_state = random_state
+        self.show_progress = show_progress
+
     def _more_tags(self):
         return {
             "requires_positive_X": True,
@@ -85,34 +89,6 @@ class BayesBase(BaseEstimator, ClassifierMixin):
         # Return the classifier
         return self

-    def _check_params_fit(self, X, y, kwargs):
-        """Check the parameters passed to fit"""
-        # Check that X and y have correct shape
-        X, y = check_X_y(X, y)
-        # Store the classes seen during fit
-        self.classes_ = unique_labels(y)
-        # Default values
-        self.class_name_ = "class"
-        self.features_ = [f"feature_{i}" for i in range(X.shape[1])]
-        self.head_ = 0
-        expected_args = ["class_name", "features", "head"]
-        for key, value in kwargs.items():
-            if key in expected_args:
-                setattr(self, f"{key}_", value)
-            else:
-                raise ValueError(f"Unexpected argument: {key}")
-        if self.random_state is not None:
-            random.seed(self.random_state)
-        if self.head_ == "random":
-            self.head_ = random.randint(0, len(self.features_) - 1)
-        if len(self.features_) != X.shape[1]:
-            raise ValueError(
-                "Number of features does not match the number of columns in X"
-            )
-        if self.head_ is not None and self.head_ >= len(self.features_):
-            raise ValueError("Head index out of range")
-        return X, y
-
     def predict(self, X):
         """A reference implementation of a prediction for a classifier.
@@ -167,17 +143,28 @@
         dataset = pd.DataFrame(X, columns=self.features_, dtype="int16")
         return self.model_.predict(dataset).values.ravel()

+    def plot(self, title="", node_size=800):
+        nx.draw_circular(
+            self.model_,
+            with_labels=True,
+            arrowsize=20,
+            node_size=node_size,
+            alpha=0.3,
+            font_weight="bold",
+        )
+        plt.title(title)
+        plt.show()
+

 class TAN(BayesBase):
     """Tree Augmented Naive Bayes

     Parameters
     ----------
-    simple_init : bool, default=True
-        How to init the initial DAG. If True, only the first feature is used
-        as father of the other features.
     random_state: int, default=None
         Random state for reproducibility
+    show_progress: bool, default=False
+        used in pgmpy to show progress bars

     Attributes
     ----------
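The plot method promoted here to BayesBase uses a plt name that this hunk does not show being imported; presumably bayesclass.py imports matplotlib.pyplot as plt elsewhere. A usage sketch under that assumption:

    from sklearn.datasets import load_iris
    from sklearn.preprocessing import KBinsDiscretizer
    from bayesclass import TAN

    X, y = load_iris(return_X_y=True)
    X = KBinsDiscretizer(encode="ordinal").fit_transform(X)
    clf = TAN().fit(X, y, head=0)
    # Draws the learned DAG with networkx and shows it via matplotlib
    clf.plot(title="TAN Iris head=0", node_size=1000)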
@@ -201,51 +188,40 @@
     The actual classifier
     """

-    def __init__(
-        self, simple_init=True, show_progress=False, random_state=None
-    ):
-        self.simple_init = simple_init
-        self.show_progress = show_progress
-        self.random_state = random_state
+    def __init__(self, show_progress=False, random_state=None):
+        super().__init__(
+            show_progress=show_progress, random_state=random_state
+        )

-    def __initial_edges(self):
-        """As with the naive Bayes, in a TAN structure, the class has no
-        parents, while features must have the class as parent and are forced to
-        have one other feature as parent too (except for one single feature,
-        which has only the class as parent and is considered the root of the
-        features' tree)
-        Cassio P. de Campos, Giorgio Corani, Mauro Scanagatta, Marco Cuccu,
-        Marco Zaffalon,
-        Learning extended tree augmented naive structures,
-        International Journal of Approximate Reasoning,
-
-        Returns
-        -------
-        List
-            List of edges
-        """
-        head = self.head_
-        if self.simple_init:
-            first_node = self.features_[head]
-            return [
-                (first_node, feature)
-                for feature in self.features_
-                if feature != first_node
-            ]
-        # initialize a complete network with all edges starting from head
-        reordered = [
-            self.features_[idx % len(self.features_)]
-            for idx in range(head, len(self.features_) + head)
-        ]
-        return list(combinations(reordered, 2))
+    def _check_params_fit(self, X, y, kwargs):
+        """Check the parameters passed to fit"""
+        # Check that X and y have correct shape
+        X, y = check_X_y(X, y)
+        # Store the classes seen during fit
+        self.classes_ = unique_labels(y)
+        # Default values
+        self.class_name_ = "class"
+        self.features_ = [f"feature_{i}" for i in range(X.shape[1])]
+        self.head_ = 0
+        expected_args = ["class_name", "features", "head"]
+        for key, value in kwargs.items():
+            if key in expected_args:
+                setattr(self, f"{key}_", value)
+            else:
+                raise ValueError(f"Unexpected argument: {key}")
+        if self.random_state is not None:
+            random.seed(self.random_state)
+        if self.head_ == "random":
+            self.head_ = random.randint(0, len(self.features_) - 1)
+        if len(self.features_) != X.shape[1]:
+            raise ValueError(
+                "Number of features does not match the number of columns in X"
+            )
+        if self.head_ is not None and self.head_ >= len(self.features_):
+            raise ValueError("Head index out of range")
+        return X, y

     def _build(self):
-        # Initialize a Naive Bayes model
-        net = [(self.class_name_, feature) for feature in self.features_]
-        self.model_ = BayesianNetwork(net)
-        # initialize a complete network with all edges
-        self.model_.add_edges_from(self.__initial_edges())
-        # learn graph structure
         est = TreeSearch(self.dataset_, root_node=self.features_[self.head_])
         self.dag_ = est.estimate(
             estimator_type="tan",
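TAN._check_params_fit turns a small set of fit keyword arguments into trailing-underscore attributes. A sketch of those kwargs (the feature and class names are hypothetical; head_ == 3 for seed 17 is the value the test suite below pins):

    from sklearn.datasets import load_iris
    from sklearn.preprocessing import KBinsDiscretizer
    from bayesclass import TAN

    X, y = load_iris(return_X_y=True)
    X = KBinsDiscretizer(encode="ordinal").fit_transform(X)
    clf = TAN(random_state=17)
    clf.fit(
        X,
        y,
        features=["sepal_l", "sepal_w", "petal_l", "petal_w"],  # hypothetical names
        class_name="species",                                   # hypothetical name
        head="random",  # resolved to a random feature index via random.seed(17)
    )
    print(clf.head_)  # the tests expect 3 for this seed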
@@ -263,31 +239,103 @@
             prior_type="K2",
         )

-    def plot(self, title=""):
-        nx.draw_circular(
-            self.model_,
-            with_labels=True,
-            arrowsize=30,
-            node_size=800,
-            alpha=0.3,
-            font_weight="bold",
-        )
-        plt.title(title)
-        plt.show()
-

-class KDBayesClassifier(BayesBase):
-    def __init__(self, k=3, random_state=None):
+class KDB(BayesBase):
+    def __init__(self, k, show_progress=False, random_state=None):
         self.k = k
-        self.random_state = random_state
+        super().__init__(
+            show_progress=show_progress, random_state=random_state
+        )

-    @staticmethod
-    def version() -> str:
-        """Return the version of the package."""
-        return __version__
+    def _check_params_fit(self, X, y, kwargs):
+        """Check the parameters passed to fit"""
+        # Check that X and y have correct shape
+        X, y = check_X_y(X, y)
+        # Store the classes seen during fit
+        self.classes_ = unique_labels(y)
+        # Default values
+        self.class_name_ = "class"
+        self.features_ = [f"feature_{i}" for i in range(X.shape[1])]
+        self.head_ = 0
+        expected_args = ["class_name", "features"]
+        for key, value in kwargs.items():
+            if key in expected_args:
+                setattr(self, f"{key}_", value)
+            else:
+                raise ValueError(f"Unexpected argument: {key}")
+        if self.random_state is not None:
+            random.seed(self.random_state)
+        if len(self.features_) != X.shape[1]:
+            raise ValueError(
+                "Number of features does not match the number of columns in X"
+            )
+        return X, y

     def _build(self):
-        pass
+        """
+        1. For each feature Xi, compute mutual information, I(Xi;C), where C is the class.
+        2. Compute class conditional mutual information I(Xi;Xj|C), for each pair of features Xi and Xj, where i != j.
+        3. Let the used variable list, S, be empty.
+        4. Let the Bayesian network being constructed, BN, begin with a single class node, C.
+        5. Repeat until S includes all domain features
+            5.1. Select feature Xmax which is not in S and has the largest value I(Xmax;C).
+            5.2. Add a node to BN representing Xmax.
+            5.3. Add an arc from C to Xmax in BN.
+            5.4. Add m = min(|S|, k) arcs from m distinct features Xj in S with the highest value for I(Xmax;Xj|C).
+            5.5. Add Xmax to S.
+        Compute the conditional probability inferred by the structure of BN by using counts from DB, and output BN.
+        """
+
+        def add_m_edges(dag, idx, S_nodes, conditional_weights):
+            n_edges = min(self.k, len(S_nodes))
+            cond_w = conditional_weights.copy()
+            exit_cond = False
+            num = 0
+            while not exit_cond:
+                max_minfo = np.argmax(cond_w[idx, :])
+                try:
+                    dag.add_edge(
+                        self.features_[max_minfo], self.features_[idx]
+                    )
+                    num += 1
+                except ValueError:
+                    # Loops are not allowed
+                    pass
+                cond_w[idx, max_minfo] = -1
+                exit_cond = num == n_edges or np.all(cond_w[idx, :] <= 0)
+
+        # 1. get the mutual information between each feature and the class
+        mutual = mutual_info_classif(self.X_, self.y_, discrete_features=True)
+        # 2. symmetric matrix where each element represents I(X, Y | class_node)
+        conditional_weights = TreeSearch(
+            self.dataset_
+        )._get_conditional_weights(
+            self.dataset_, self.class_name_, show_progress=self.show_progress
+        )
+        # 3.
+        S_nodes = []
+        # 4.
+        dag = BayesianNetwork()
+        dag.add_node(self.class_name_)  # , state_names=self.classes_)
+        # 5. 5.1
+        for idx in np.argsort(mutual):
+            # 5.2
+            feature = self.features_[idx]
+            dag.add_node(feature)
+            # 5.3
+            dag.add_edge(self.class_name_, feature)
+            # 5.4
+            add_m_edges(dag, idx, S_nodes, conditional_weights)
+            # 5.5
+            S_nodes.append(idx)
+        self.dag_ = dag

     def _train(self):
-        pass
+        self.model_ = BayesianNetwork(
+            self.dag_.edges(), show_progress=self.show_progress
+        )
+        self.model_.fit(
+            self.dataset_,
+            estimator=BayesianEstimator,
+            prior_type="K2",
+        )
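A toy, dependency-light walkthrough of steps 3 to 5 may help; the mutual-information numbers below are invented and the acyclicity retry that add_m_edges performs is skipped. Note that step 5.1 asks for the feature with the largest I(Xmax;C) first, while the committed loop iterates np.argsort(mutual) in ascending order; the sketch follows the written description:

    import numpy as np

    k = 2
    mutual = np.array([0.10, 0.90, 0.40, 0.70])  # toy I(Xi; C) values
    cond_mi = np.array(                          # toy symmetric I(Xi; Xj | C)
        [
            [0.0, 0.3, 0.1, 0.2],
            [0.3, 0.0, 0.5, 0.4],
            [0.1, 0.5, 0.0, 0.6],
            [0.2, 0.4, 0.6, 0.0],
        ]
    )
    S, arcs = [], []                                 # steps 3 and 4
    for idx in np.argsort(mutual)[::-1].tolist():    # step 5.1: most informative first
        arcs.append(("C", idx))                      # steps 5.2 and 5.3
        parents = sorted(S, key=lambda j: cond_mi[idx, j], reverse=True)
        arcs += [(j, idx) for j in parents[: min(len(S), k)]]  # step 5.4
        S.append(idx)                                # step 5.5
    # [('C', 1), ('C', 3), (1, 3), ('C', 2), (3, 2), (1, 2), ('C', 0), (1, 0), (3, 0)]
    print(arcs)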
BIN  bayesclass/tests/baseline_images/test_KDB/line_dashes_KDB.png (new file; binary file not shown; after: 50 KiB)
BIN  bayesclass/tests/baseline_images/test_TAN/line_dashes_TAN.png (new file; binary file not shown; after: 44 KiB)
BIN  deleted baseline image (binary file not shown; before: 45 KiB)
92  bayesclass/tests/test_KDB.py (new file)
@@ -0,0 +1,92 @@
+import pytest
+import numpy as np
+from sklearn.datasets import load_iris
+from sklearn.preprocessing import KBinsDiscretizer
+from matplotlib.testing.decorators import image_comparison
+from matplotlib.testing.conftest import mpl_test_settings
+
+
+from bayesclass import KDB
+from .._version import __version__
+
+
+@pytest.fixture
+def data():
+    X, y = load_iris(return_X_y=True)
+    enc = KBinsDiscretizer(encode="ordinal")
+    return enc.fit_transform(X), y
+
+
+@pytest.fixture
+def clf():
+    return KDB(k=3)
+
+
+def test_KDB_default_hyperparameters(data, clf):
+    # Test default values of hyperparameters
+    assert not clf.show_progress
+    assert clf.random_state is None
+    clf = KDB(show_progress=True, random_state=17, k=3)
+    assert clf.show_progress
+    assert clf.random_state == 17
+    clf.fit(*data)
+    assert clf.class_name_ == "class"
+    assert clf.features_ == [
+        "feature_0",
+        "feature_1",
+        "feature_2",
+        "feature_3",
+    ]
+
+
+def test_KDB_version(clf):
+    """Check KDB version."""
+    assert __version__ == clf.version()
+
+
+def test_KDB_nodes_leaves(clf):
+    assert clf.nodes_leaves() == (0, 0)
+
+
+def test_KDB_classifier(data, clf):
+    clf.fit(*data)
+    attribs = ["classes_", "X_", "y_", "features_", "class_name_"]
+    for attr in attribs:
+        assert hasattr(clf, attr)
+    X = data[0]
+    y = data[1]
+    y_pred = clf.predict(X)
+    assert y_pred.shape == (X.shape[0],)
+    assert sum(y == y_pred) == 147
+
+
+@image_comparison(
+    baseline_images=["line_dashes_KDB"], remove_text=True, extensions=["png"]
+)
+def test_KDB_plot(data, clf):
+    # mpl_test_settings will automatically clean these internal side effects
+    mpl_test_settings
+    dataset = load_iris(as_frame=True)
+    clf.fit(*data, features=dataset["feature_names"])
+    clf.plot("KDB Iris")
+
+
+def test_KDB_wrong_num_features(data, clf):
+    with pytest.raises(
+        ValueError,
+        match="Number of features does not match the number of columns in X",
+    ):
+        clf.fit(*data, features=["feature_1", "feature_2"])
+
+
+def test_KDB_wrong_hyperparam(data, clf):
+    with pytest.raises(ValueError, match="Unexpected argument: wrong_param"):
+        clf.fit(*data, wrong_param="wrong_param")
+
+
+def test_KDB_error_size_predict(data, clf):
+    X, y = data
+    clf.fit(X, y)
+    with pytest.raises(ValueError):
+        X_diff_size = np.ones((10, X.shape[1] + 1))
+        clf.predict(X_diff_size)
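Condensed, the new test file exercises KDB end to end like this sketch; the 147-of-150 agreement on the training data is the figure the tests assert:

    from sklearn.datasets import load_iris
    from sklearn.preprocessing import KBinsDiscretizer
    from bayesclass import KDB

    X, y = load_iris(return_X_y=True)
    X = KBinsDiscretizer(encode="ordinal").fit_transform(X)
    clf = KDB(k=3).fit(X, y)
    print((clf.predict(X) == y).sum())  # tests assert 147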
bayesclass/tests/test_TAN.py
@@ -17,14 +17,16 @@ def data():
     return enc.fit_transform(X), y


-def test_TAN_default_hyperparameters(data):
-    clf = TAN()
+@pytest.fixture
+def clf():
+    return TAN()
+
+
+def test_TAN_default_hyperparameters(data, clf):
     # Test default values of hyperparameters
-    assert clf.simple_init
     assert not clf.show_progress
     assert clf.random_state is None
-    clf = TAN(simple_init=True, show_progress=True, random_state=17)
+    clf = TAN(show_progress=True, random_state=17)
-    assert clf.simple_init
     assert clf.show_progress
     assert clf.random_state == 17
     clf.fit(*data)
@@ -38,34 +40,26 @@ def test_TAN_default_hyperparameters(data):
     ]


-def test_TAN_version():
+def test_TAN_version(clf):
     """Check TAN version."""
-    clf = TAN()
     assert __version__ == clf.version()


+def test_TAN_nodes_leaves(clf):
+    assert clf.nodes_leaves() == (0, 0)
+
+
 def test_TAN_random_head(data):
     clf = TAN(random_state=17)
     clf.fit(*data, head="random")
     assert clf.head_ == 3


-def test_TAN_dag_initializer(data):
-    clf_not_simple = TAN(simple_init=False)
-    clf_simple = TAN(simple_init=True)
-    clf_not_simple.fit(*data, head=0)
-    clf_simple.fit(*data, head=0)
-    assert clf_simple.dag_.edges == clf_not_simple.dag_.edges
-
-
-def test_TAN_classifier(data):
-    clf = TAN()
-
+def test_TAN_classifier(data, clf):
     clf.fit(*data)
     attribs = ["classes_", "X_", "y_", "head_", "features_", "class_name_"]
     for attr in attribs:
         assert hasattr(clf, attr)

     X = data[0]
     y = data[1]
     y_pred = clf.predict(X)
@@ -74,40 +68,17 @@ def test_TAN_classifier(data):


 @image_comparison(
-    baseline_images=["line_dashes"], remove_text=True, extensions=["png"]
+    baseline_images=["line_dashes_TAN"], remove_text=True, extensions=["png"]
 )
-def test_TAN_plot(data):
+def test_TAN_plot(data, clf):
     # mpl_test_settings will automatically clean these internal side effects
     mpl_test_settings
-    clf = TAN()
     dataset = load_iris(as_frame=True)
     clf.fit(*data, features=dataset["feature_names"], head=0)
     clf.plot("TAN Iris head=0")


-def test_TAN_classifier_simple_init(data):
-    dataset = load_iris(as_frame=True)
-    features = dataset["feature_names"]
-    clf = TAN(simple_init=True)
-    clf.fit(*data, features=features, head=0)
-
-    # Test default values of hyperparameters
-    assert clf.simple_init
-
-    clf.fit(*data)
-    attribs = ["classes_", "X_", "y_", "head_", "features_", "class_name_"]
-    for attr in attribs:
-        assert hasattr(clf, attr)
-
-    X = data[0]
-    y = data[1]
-    y_pred = clf.predict(X)
-    assert y_pred.shape == (X.shape[0],)
-    assert sum(y == y_pred) == 147
-
-
-def test_TAN_wrong_num_features(data):
-    clf = TAN()
+def test_TAN_wrong_num_features(data, clf):
     with pytest.raises(
         ValueError,
         match="Number of features does not match the number of columns in X",
@@ -115,21 +86,18 @@ def test_TAN_wrong_num_features(data):
         clf.fit(*data, features=["feature_1", "feature_2"])


-def test_TAN_wrong_hyperparam(data):
-    clf = TAN()
+def test_TAN_wrong_hyperparam(data, clf):
     with pytest.raises(ValueError, match="Unexpected argument: wrong_param"):
         clf.fit(*data, wrong_param="wrong_param")


-def test_TAN_head_out_of_range(data):
-    clf = TAN()
+def test_TAN_head_out_of_range(data, clf):
     with pytest.raises(ValueError, match="Head index out of range"):
         clf.fit(*data, head=4)


-def test_TAN_error_size_predict(data):
+def test_TAN_error_size_predict(data, clf):
     X, y = data
-    clf = TAN()
     clf.fit(X, y)
     with pytest.raises(ValueError):
         X_diff_size = np.ones((10, X.shape[1] + 1))