mirror of
https://github.com/Doctorado-ML/bayesclass.git
synced 2025-08-15 23:55:57 +00:00
Begin AODENew with tests
This commit is contained in:
@@ -382,6 +382,18 @@ class KDB(BayesBase):
|
||||
self.dag_ = dag
|
||||
|
||||
|
||||
def build_spode(features, class_name):
|
||||
"""Build SPODE estimators (Super Parent One Dependent Estimator)"""
|
||||
class_edges = [(class_name, f) for f in features]
|
||||
for idx in range(len(features)):
|
||||
feature_edges = [
|
||||
(features[idx], f) for f in features if f != features[idx]
|
||||
]
|
||||
feature_edges.extend(class_edges)
|
||||
model = BayesianNetwork(feature_edges, show_progress=False)
|
||||
yield model
|
||||
|
||||
|
||||
class AODE(BayesBase, BaseEnsemble):
|
||||
def __init__(self, show_progress=False, random_state=None):
|
||||
super().__init__(
|
||||
@@ -416,20 +428,9 @@ class AODE(BayesBase, BaseEnsemble):
|
||||
self.dag_ = None
|
||||
|
||||
def _train(self, kwargs):
|
||||
"""Build SPODE estimators (Super Parent One Dependent Estimator)"""
|
||||
self.models_ = []
|
||||
class_edges = [(self.class_name_, f) for f in self.feature_names_in_]
|
||||
states = dict(state_names=kwargs.pop("state_names", []))
|
||||
for idx in range(self.n_features_in_):
|
||||
feature_edges = [
|
||||
(self.feature_names_in_[idx], f)
|
||||
for f in self.feature_names_in_
|
||||
if f != self.feature_names_in_[idx]
|
||||
]
|
||||
feature_edges.extend(class_edges)
|
||||
model = BayesianNetwork(
|
||||
feature_edges, show_progress=self.show_progress
|
||||
)
|
||||
for model in build_spode(self.feature_names_in_, self.class_name_):
|
||||
model.fit(
|
||||
self.dataset_,
|
||||
estimator=BayesianEstimator,
|
||||
@@ -510,24 +511,38 @@ class KDBNew(KDB):
|
||||
|
||||
|
||||
class AODENew(AODE):
|
||||
def fit(self, X, y, **kwargs):
|
||||
self.estimator = Proposal(self)
|
||||
return self.estimator.fit(X, y, **kwargs)
|
||||
def _train(self, kwargs):
|
||||
self.estimators_ = []
|
||||
states = dict(state_names=kwargs.pop("state_names", []))
|
||||
kwargs["states"] = states
|
||||
for model in build_spode(self.feature_names_in_, self.class_name_):
|
||||
estimator = Proposal(model)
|
||||
self.estimators_.append(estimator.fit(self.X_, self.y_, **kwargs))
|
||||
return self
|
||||
|
||||
def predict(self, X: np.ndarray) -> np.ndarray:
|
||||
check_is_fitted(self, ["X_", "y_", "fitted_"])
|
||||
# Input validation
|
||||
X = check_array(X)
|
||||
n_samples = X.shape[0]
|
||||
n_estimators = len(self.models_)
|
||||
n_estimators = len(self.estimators_)
|
||||
result = np.empty((n_samples, n_estimators))
|
||||
dataset = pd.DataFrame(
|
||||
X, columns=self.feature_names_in_, dtype=np.int32
|
||||
)
|
||||
for index, model in enumerate(self.models_):
|
||||
result[:, index] = model.predict(dataset).values.ravel()
|
||||
for index, model in enumerate(self.estimators_):
|
||||
result[:, index] = model.predict(X).values.ravel()
|
||||
return mode(result, axis=1, keepdims=False).mode.ravel()
|
||||
|
||||
@property
|
||||
def states_(self):
|
||||
if hasattr(self, "fitted_"):
|
||||
return sum(
|
||||
[
|
||||
len(item)
|
||||
for model in self.models_
|
||||
for _, item in model.states.items()
|
||||
]
|
||||
) / len(self.models_)
|
||||
return 0
|
||||
|
||||
|
||||
class Proposal:
|
||||
def __init__(self, estimator):
|
||||
@@ -557,12 +572,12 @@ class Proposal:
|
||||
if upgraded:
|
||||
kwargs = self.update_kwargs(y, kwargs)
|
||||
super(self.class_type, self.estimator).fit(self.Xd, y, **kwargs)
|
||||
return self
|
||||
|
||||
def predict(self, X):
|
||||
self.check_integrity("predict", self.discretizer.transform(X))
|
||||
return super(self.class_type, self.estimator).predict(
|
||||
self.discretizer.transform(X)
|
||||
)
|
||||
Xd = self.discretizer.transform(X)
|
||||
self.check_integrity("predict", Xd)
|
||||
return super(self.class_type, self.estimator).predict(Xd)
|
||||
|
||||
def update_kwargs(self, y, kwargs):
|
||||
features = (
|
||||
|
@@ -19,16 +19,16 @@ def data():
|
||||
|
||||
@pytest.fixture
|
||||
def clf():
|
||||
return AODE()
|
||||
return AODE(random_state=17)
|
||||
|
||||
|
||||
def test_AODE_default_hyperparameters(data, clf):
|
||||
# Test default values of hyperparameters
|
||||
assert not clf.show_progress
|
||||
assert clf.random_state is None
|
||||
clf = AODE(show_progress=True, random_state=17)
|
||||
assert clf.show_progress
|
||||
assert clf.random_state == 17
|
||||
clf = AODE(show_progress=True)
|
||||
assert clf.show_progress
|
||||
assert clf.random_state is None
|
||||
clf.fit(*data)
|
||||
assert clf.class_name_ == "class"
|
||||
assert clf.feature_names_in_ == [
|
||||
@@ -63,7 +63,6 @@ def test_AODE_nodes_edges(clf, data):
|
||||
|
||||
def test_AODE_states(clf, data):
|
||||
assert clf.states_ == 0
|
||||
clf = AODE(random_state=17)
|
||||
clf.fit(*data)
|
||||
assert clf.states_ == 23
|
||||
assert clf.depth_ == clf.states_
|
||||
|
108
bayesclass/tests/test_AODENew.py
Normal file
108
bayesclass/tests/test_AODENew.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.preprocessing import KBinsDiscretizer
|
||||
from matplotlib.testing.decorators import image_comparison
|
||||
from matplotlib.testing.conftest import mpl_test_settings
|
||||
|
||||
|
||||
from bayesclass.clfs import AODENew
|
||||
from .._version import __version__
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def data():
|
||||
X, y = load_iris(return_X_y=True)
|
||||
enc = KBinsDiscretizer(encode="ordinal")
|
||||
return enc.fit_transform(X), y
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def clf():
|
||||
return AODENew(random_state=17)
|
||||
|
||||
|
||||
def test_AODENew_default_hyperparameters(data, clf):
|
||||
# Test default values of hyperparameters
|
||||
assert not clf.show_progress
|
||||
assert clf.random_state == 17
|
||||
clf = AODENew(show_progress=True)
|
||||
assert clf.show_progress
|
||||
assert clf.random_state is None
|
||||
clf.fit(*data)
|
||||
assert clf.class_name_ == "class"
|
||||
assert clf.feature_names_in_ == [
|
||||
"feature_0",
|
||||
"feature_1",
|
||||
"feature_2",
|
||||
"feature_3",
|
||||
]
|
||||
|
||||
|
||||
@image_comparison(
|
||||
baseline_images=["line_dashes_AODE"], remove_text=True, extensions=["png"]
|
||||
)
|
||||
def test_AODENew_plot(data, clf):
|
||||
# mpl_test_settings will automatically clean these internal side effects
|
||||
mpl_test_settings
|
||||
dataset = load_iris(as_frame=True)
|
||||
clf.fit(*data, features=dataset["feature_names"])
|
||||
clf.plot("AODE Iris")
|
||||
|
||||
|
||||
def test_AODENew_version(clf):
|
||||
"""Check AODE version."""
|
||||
assert __version__ == clf.version()
|
||||
|
||||
|
||||
def test_AODENew_nodes_edges(clf, data):
|
||||
assert clf.nodes_edges() == (0, 0)
|
||||
clf.fit(*data)
|
||||
assert clf.nodes_leaves() == (20, 28)
|
||||
|
||||
|
||||
def test_AODENew_states(clf, data):
|
||||
assert clf.states_ == 0
|
||||
clf.fit(*data)
|
||||
assert clf.states_ == 23
|
||||
assert clf.depth_ == clf.states_
|
||||
|
||||
|
||||
def test_AODENew_classifier(data, clf):
|
||||
clf.fit(*data)
|
||||
attribs = [
|
||||
"classes_",
|
||||
"X_",
|
||||
"y_",
|
||||
"feature_names_in_",
|
||||
"class_name_",
|
||||
"n_features_in_",
|
||||
]
|
||||
for attr in attribs:
|
||||
assert hasattr(clf, attr)
|
||||
X = data[0]
|
||||
y = data[1]
|
||||
y_pred = clf.predict(X)
|
||||
assert y_pred.shape == (X.shape[0],)
|
||||
assert sum(y == y_pred) == 147
|
||||
|
||||
|
||||
def test_AODENew_wrong_num_features(data, clf):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Number of features does not match the number of columns in X",
|
||||
):
|
||||
clf.fit(*data, features=["feature_1", "feature_2"])
|
||||
|
||||
|
||||
def test_AODENew_wrong_hyperparam(data, clf):
|
||||
with pytest.raises(ValueError, match="Unexpected argument: wrong_param"):
|
||||
clf.fit(*data, wrong_param="wrong_param")
|
||||
|
||||
|
||||
def test_AODENew_error_size_predict(data, clf):
|
||||
X, y = data
|
||||
clf.fit(X, y)
|
||||
with pytest.raises(ValueError):
|
||||
X_diff_size = np.ones((10, X.shape[1] + 1))
|
||||
clf.predict(X_diff_size)
|
Reference in New Issue
Block a user