Mirror of https://github.com/Doctorado-ML/bayesclass.git, synced 2025-08-18 17:15:53 +00:00
Refactor AODE & AODENew
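AODE is rebuilt on scikit-learn's ensemble protocol: a new SPODE class derives from BayesBase, and AODE now extends BaseEnsemble, cloning one SPODE per parent attribute into estimators_ instead of driving a shared base_model over a models_ list. AODENew follows the same pattern through a cloneable Proposal wrapper (now a BaseEstimator), fitted attributes take the trailing-underscore convention (estimator_, discretizer_), fit() in TANNew and KDBNew now returns self, default_feature_names moves to module level, and SpodeNew is renamed SPODENew. A sketch of the ensemble pattern and a usage example follow the relevant hunks below.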
@@ -3,7 +3,7 @@ import warnings
 import numpy as np
 import pandas as pd
 from scipy.stats import mode
-from sklearn.base import ClassifierMixin, BaseEstimator
+from sklearn.base import clone, ClassifierMixin, BaseEstimator
 from sklearn.ensemble import BaseEnsemble
 from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
 from sklearn.utils.multiclass import unique_labels
@@ -16,6 +16,10 @@ from fimdlp.mdlp import FImdlp
 from ._version import __version__
 
 
+def default_feature_names(num_features):
+    return [f"feature_{i}" for i in range(num_features)]
+
+
 class BayesBase(BaseEstimator, ClassifierMixin):
     def __init__(self, random_state, show_progress):
         self.random_state = random_state
@@ -39,10 +43,6 @@ class BayesBase(BaseEstimator, ClassifierMixin):
             return len(self.dag_), len(self.dag_.edges())
         return 0, 0
 
-    @staticmethod
-    def default_feature_names(num_features):
-        return [f"feature_{i}" for i in range(num_features)]
-
     @staticmethod
     def default_class_name():
         return "class"
@@ -57,7 +57,7 @@ class BayesBase(BaseEstimator, ClassifierMixin):
         self.n_classes_ = self.classes_.shape[0]
         # Default values
         self.class_name_ = self.default_class_name()
-        self.features_ = self.default_feature_names(X.shape[1])
+        self.features_ = default_feature_names(X.shape[1])
         for key, value in kwargs.items():
             if key in expected_args:
                 setattr(self, f"{key}_", value)
@@ -139,6 +139,9 @@ class BayesBase(BaseEstimator, ClassifierMixin):
         # Return the classifier
         return self
 
+    def _build(self):
+        ...
+
     def _train(self, kwargs):
         self.model_ = BayesianNetwork(
             self.dag_.edges(), show_progress=self.show_progress
@@ -394,97 +397,97 @@ def build_spodes(features, class_name):
         yield model
 
 
-class AODE(ClassifierMixin, BaseEnsemble):
-    def __init__(self, show_progress=False, random_state=None):
-        self.base_model = BayesBase(
-            show_progress=show_progress, random_state=random_state
-        )
-        self.show_progress = show_progress
-        self.random_state = random_state
-
+class SPODE(BayesBase):
     def _check_params(self, X, y, kwargs):
         expected_args = ["class_name", "features", "state_names"]
-        return self.base_model._check_params_fit(X, y, expected_args, kwargs)
+        return self._check_params_fit(X, y, expected_args, kwargs)
 
-    def nodes_edges(self):
-        nodes = 0
-        edges = 0
-        if hasattr(self, "fitted_"):
-            nodes = sum([len(x) for x in self.models_])
-            edges = sum([len(x.edges()) for x in self.models_])
-        return nodes, edges
 
-    def version(self):
-        return self.base_model.version()
+class AODE(ClassifierMixin, BaseEnsemble):
+    def __init__(
+        self,
+        show_progress=False,
+        random_state=None,
+        estimator=None,
+    ):
+        self.show_progress = show_progress
+        self.random_state = random_state
+        super().__init__(estimator=estimator)
+
+    def _validate_estimator(self) -> None:
+        """Check the estimator and set the estimator_ attribute."""
+        super()._validate_estimator(
+            default=SPODE(
+                random_state=self.random_state,
+                show_progress=self.show_progress,
+            )
+        )
 
     def fit(self, X, y, **kwargs):
-        X_, y_ = self._check_params(X, y, kwargs)
-        self.class_name_ = self.base_model.class_name_
-        self.feature_names_in_ = self.base_model.feature_names_in_
-        self.classes_ = self.base_model.classes_
-        self.n_features_in_ = self.base_model.n_features_in_
-        # Store the information needed to build the model
-        self.X_ = X_
-        self.y_ = y_
-        self.dataset_ = pd.DataFrame(
-            self.X_, columns=self.feature_names_in_, dtype=np.int32
-        )
-        self.dataset_[self.class_name_] = self.y_
-        # Train the model
+        self.n_features_in_ = X.shape[1]
+        self.feature_names_in_ = kwargs.get(
+            "features", default_feature_names(self.n_features_in_)
+        )
+        self.class_name_ = kwargs.get("class_name", "class")
+        # build estimator
+        self._validate_estimator()
+        self.X_ = X
+        self.y_ = y
+        self.estimators_ = []
         self._train(kwargs)
-        self.fitted_ = True
         # To keep compatiblity with the benchmark platform
+        self.fitted_ = True
         self.nodes_leaves = self.nodes_edges
-        # Return the classifier
         return self
 
+    def _train(self, kwargs):
+        for dag in build_spodes(self.feature_names_in_, self.class_name_):
+            estimator = clone(self.estimator_)
+            estimator.dag_ = estimator.model_ = dag
+            estimator.fit(self.X_, self.y_, **kwargs)
+            self.estimators_.append(estimator)
+
+    def predict(self, X: np.ndarray) -> np.ndarray:
+        n_samples = X.shape[0]
+        n_estimators = len(self.estimators_)
+        result = np.empty((n_samples, n_estimators))
+        for index, estimator in enumerate(self.estimators_):
+            result[:, index] = estimator.predict(X)
+        return mode(result, axis=1, keepdims=False).mode.ravel()
+
+    def version(self):
+        if hasattr(self, "fitted_"):
+            return self.estimator_.version()
+        return SPODE(None, False).version()
+
     @property
     def states_(self):
         if hasattr(self, "fitted_"):
             return sum(
                 [
                     len(item)
-                    for model in self.models_
-                    for _, item in model.states.items()
+                    for model in self.estimators_
+                    for _, item in model.model_.states.items()
                 ]
-            ) / len(self.models_)
+            ) / len(self.estimators_)
         return 0
 
     @property
     def depth_(self):
         return self.states_
 
-    def _train(self, kwargs):
-        self.models_ = []
-        states = dict(state_names=kwargs.pop("state_names", []))
-        for model in build_spodes(self.feature_names_in_, self.class_name_):
-            model.fit(
-                self.dataset_,
-                estimator=BayesianEstimator,
-                prior_type="K2",
-                **states,
-            )
-            self.models_.append(model)
+    def nodes_edges(self):
+        nodes = 0
+        edges = 0
+        if hasattr(self, "fitted_"):
+            nodes = sum([len(x.dag_) for x in self.estimators_])
+            edges = sum([len(x.dag_.edges()) for x in self.estimators_])
+        return nodes, edges
 
     def plot(self, title=""):
         warnings.simplefilter("ignore", UserWarning)
-        for idx, model in enumerate(self.models_):
-            self.base_model.model_ = model
-            self.base_model.plot(title=f"{idx} {title}")
-
-    def predict(self, X: np.ndarray) -> np.ndarray:
-        check_is_fitted(self, ["X_", "y_", "fitted_"])
-        # Input validation
-        X = check_array(X)
-        n_samples = X.shape[0]
-        n_estimators = len(self.models_)
-        result = np.empty((n_samples, n_estimators))
-        dataset = pd.DataFrame(
-            X, columns=self.feature_names_in_, dtype=np.int32
-        )
-        for index, model in enumerate(self.models_):
-            result[:, index] = model.predict(dataset).values.ravel()
-        return mode(result, axis=1, keepdims=False).mode.ravel()
+        for idx, model in enumerate(self.estimators_):
+            model.plot(title=f"{idx} {title}")
 
 
 class TANNew(TAN):
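The core of the hunk above is scikit-learn's BaseEnsemble protocol: _validate_estimator resolves the estimator parameter (or a default) into estimator_, and training clones that prototype once per ensemble member. A minimal illustrative sketch of the same pattern, with a hypothetical TinyEnsemble and a stand-in DummyClassifier instead of SPODE, assuming a scikit-learn version where BaseEnsemble takes an estimator keyword (1.2 or later):

# Illustrative sketch, not part of the commit: the clone-per-member
# BaseEnsemble pattern that AODE adopts above.
from sklearn.base import clone
from sklearn.dummy import DummyClassifier
from sklearn.ensemble import BaseEnsemble


class TinyEnsemble(BaseEnsemble):
    def __init__(self, estimator=None):
        super().__init__(estimator=estimator)

    def _validate_estimator(self) -> None:
        # Fall back to a default prototype when none is given,
        # as AODE falls back to SPODE.
        super()._validate_estimator(default=DummyClassifier())

    def fit(self, X, y):
        self._validate_estimator()  # sets self.estimator_
        # One fresh, unfitted copy of the prototype per ensemble member.
        self.estimators_ = [clone(self.estimator_) for _ in range(5)]
        for member in self.estimators_:
            member.fit(X, y)
        return self

Cloning matters here: each SPODE gets its own DAG and its own fitted CPTs, while the prototype in estimator_ stays unfitted and reusable.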
@@ -504,11 +507,12 @@ class TANNew(TAN):
         )
 
     def fit(self, X, y, **kwargs):
-        self.estimator = Proposal(self)
-        return self.estimator.fit(X, y, **kwargs)
+        self.estimator_ = Proposal(self)
+        self.estimator_.fit(X, y, **kwargs)
+        return self
 
     def predict(self, X):
-        return self.estimator.predict(X)
+        return self.estimator_.predict(X)
 
 
 class KDBNew(KDB):
@@ -529,15 +533,17 @@ class KDBNew(KDB):
         )
 
     def fit(self, X, y, **kwargs):
-        self.estimator = Proposal(self)
-        return self.estimator.fit(X, y, **kwargs)
+        self.estimator_ = Proposal(self)
+        self.estimator_.fit(X, y, **kwargs)
+        return self
 
     def predict(self, X):
-        return self.estimator.predict(X)
+        return self.estimator_.predict(X)
 
 
-class SpodeNew(BayesBase):
-    """This class implements a classifier for the SPODE algorithm similar to TANNew and KDBNew"""
+class SPODENew(SPODE):
+    """This class implements a classifier for the SPODE algorithm similar to
+    TANNew and KDBNew"""
 
     def __init__(
         self,
@@ -554,13 +560,6 @@ class SpodeNew(BayesBase):
         self.discretizer_length = discretizer_length
         self.discretizer_cuts = discretizer_cuts
 
-    def _check_params(self, X, y, kwargs):
-        expected_args = ["class_name", "features", "state_names"]
-        return self._check_params_fit(X, y, expected_args, kwargs)
-
-    def _build(self):
-        ...
-
 
 class AODENew(AODE):
     def __init__(
@@ -575,31 +574,32 @@ class AODENew(AODE):
         self.discretizer_length = discretizer_length
         self.discretizer_cuts = discretizer_cuts
         super().__init__(
-            show_progress=show_progress, random_state=random_state
+            random_state=random_state,
+            show_progress=show_progress,
+            estimator=Proposal(
+                SPODENew(
+                    random_state=random_state,
+                    show_progress=show_progress,
+                    discretizer_depth=discretizer_depth,
+                    discretizer_length=discretizer_length,
+                    discretizer_cuts=discretizer_cuts,
+                )
+            ),
         )
 
     def _train(self, kwargs):
-        self.models_ = []
-        for model in build_spodes(self.feature_names_in_, self.class_name_):
-            spode = SpodeNew(
-                random_state=self.random_state,
-                show_progress=self.show_progress,
-                discretizer_cuts=self.discretizer_cuts,
-                discretizer_depth=self.discretizer_depth,
-                discretizer_length=self.discretizer_length,
-            )
-            spode.dag_ = model
-            estimator = Proposal(spode)
-            self.models_.append(estimator.fit(self.X_, self.y_, **kwargs))
+        for dag in build_spodes(self.feature_names_in_, self.class_name_):
+            proposal = clone(self.estimator_)
+            proposal.estimator.dag_ = proposal.estimator.model_ = dag
+            self.estimators_.append(proposal.fit(self.X_, self.y_, **kwargs))
+        self.n_estimators_ = len(self.estimators_)
 
     def predict(self, X: np.ndarray) -> np.ndarray:
         check_is_fitted(self, ["X_", "y_", "fitted_"])
         # Input validation
         X = check_array(X)
-        n_samples = X.shape[0]
-        n_estimators = len(self.models_)
-        result = np.empty((n_samples, n_estimators))
-        for index, model in enumerate(self.models_):
+        result = np.empty((X.shape[0], self.n_estimators_))
+        for index, model in enumerate(self.estimators_):
             result[:, index] = model.predict(X)
         return mode(result, axis=1, keepdims=False).mode.ravel()
 
@@ -607,26 +607,40 @@ class AODENew(AODE):
     def states_(self):
         if hasattr(self, "fitted_"):
             return sum(
-                [model.estimator.states_ for model in self.models_]
-            ) / len(self.models_)
+                [
+                    len(item)
+                    for model in self.estimators_
+                    for _, item in model.estimator.model_.states.items()
+                ]
+            ) / len(self.estimators_)
         return 0
 
+    @property
+    def depth_(self):
+        return self.states_
+
     def nodes_edges(self):
-        nodes = [0]
-        edges = [0]
+        nodes = 0
+        edges = 0
         if hasattr(self, "fitted_"):
-            nodes, edges = zip(
-                *[model.estimator.nodes_edges() for model in self.models_]
+            nodes = sum([len(x.estimator.dag_) for x in self.estimators_])
+            edges = sum(
+                [len(x.estimator.dag_.edges()) for x in self.estimators_]
             )
-        return sum(nodes), sum(edges)
+        return nodes, edges
 
     def plot(self, title=""):
         warnings.simplefilter("ignore", UserWarning)
-        for idx, model in enumerate(self.models_):
+        for idx, model in enumerate(self.estimators_):
             model.estimator.plot(title=f"{idx} {title}")
 
+    def version(self):
+        if hasattr(self, "fitted_"):
+            return self.estimator_.estimator.version()
+        return SPODENew(None, False).version()
+
 
-class Proposal:
+class Proposal(BaseEstimator):
     def __init__(self, estimator):
         self.estimator = estimator
         self.class_type = estimator.__class__
@@ -635,13 +649,13 @@ class Proposal:
         # Check parameters
         self.estimator._check_params(X, y, kwargs)
         # Discretize train data
-        self.discretizer = FImdlp(
+        self.discretizer_ = FImdlp(
             n_jobs=1,
             max_depth=self.estimator.discretizer_depth,
             min_length=self.estimator.discretizer_length,
             max_cuts=self.estimator.discretizer_cuts,
         )
-        self.Xd = self.discretizer.fit_transform(X, y)
+        self.Xd = self.discretizer_.fit_transform(X, y)
         kwargs = self.update_kwargs(y, kwargs)
         # Build the model
         super(self.class_type, self.estimator).fit(self.Xd, y, **kwargs)
@@ -662,7 +676,7 @@ class Proposal:
         check_is_fitted(self, ["fitted_"])
         # Input validation
         X = check_array(X)
-        Xd = self.discretizer.transform(X)
+        Xd = self.discretizer_.transform(X)
         # self.check_integrity("predict", Xd)
         return super(self.class_type, self.estimator).predict(Xd)
 
@@ -670,10 +684,10 @@ class Proposal:
         features = (
             kwargs["features"]
             if "features" in kwargs
-            else self.estimator.default_feature_names(self.Xd.shape[1])
+            else default_feature_names(self.Xd.shape[1])
        )
         states = {
-            features[i]: self.discretizer.get_states_feature(i)
+            features[i]: self.discretizer_.get_states_feature(i)
             for i in range(self.Xd.shape[1])
         }
         class_name = (
@@ -706,7 +720,7 @@ class Proposal:
         # Get the fathers indices
         features = [self.idx_features_[f] for f in fathers]
         # Update the discretization of the feature
-        res[:, idx] = self.discretizer.join_fit(
+        res[:, idx] = self.discretizer_.join_fit(
             target=idx, features=features, data=self.Xd
         )
         # print(self.discretizer.y_join[:5])
@@ -50,9 +50,12 @@ def test_AODE_plot(data, clf):
     clf.plot("AODE Iris")
 
 
-def test_AODE_version(clf):
+def test_AODE_version(clf, data):
     """Check AODE version."""
     assert __version__ == clf.version()
+    dataset = load_iris(as_frame=True)
+    clf.fit(*data, features=dataset["feature_names"])
+    assert __version__ == clf.version()
 
 
 def test_AODE_nodes_edges(clf, data):
@@ -71,12 +74,11 @@ def test_AODE_states(clf, data):
 def test_AODE_classifier(data, clf):
     clf.fit(*data)
     attribs = [
-        "classes_",
-        "X_",
-        "y_",
         "feature_names_in_",
         "class_name_",
         "n_features_in_",
+        "X_",
+        "y_",
     ]
     for attr in attribs:
         assert hasattr(clf, attr)
@@ -52,8 +52,11 @@ def test_AODENew_plot(data, clf):
     clf.plot("AODE Iris")
 
 
-def test_AODENew_version(clf):
-    """Check AODE version."""
+def test_AODENew_version(clf, data):
+    """Check AODENew version."""
+    assert __version__ == clf.version()
+    dataset = load_iris(as_frame=True)
+    clf.fit(*data, features=dataset["feature_names"])
     assert __version__ == clf.version()
 
 
@@ -73,12 +76,11 @@ def test_AODENew_states(clf, data):
 def test_AODENew_classifier(data, clf):
     clf.fit(*data)
     attribs = [
-        "classes_",
-        "X_",
-        "y_",
         "feature_names_in_",
         "class_name_",
         "n_features_in_",
+        "X_",
+        "y_",
     ]
     for attr in attribs:
         assert hasattr(clf, attr)
@@ -1,8 +1,23 @@
 import pytest
+import numpy as np
 
 from sklearn.utils.estimator_checks import check_estimator
 
-from bayesclass.clfs import TAN, KDB, AODE
+from bayesclass.clfs import BayesBase, TAN, KDB, AODE
 
 
+def test_more_tags():
+    expected = {
+        "requires_positive_X": True,
+        "requires_positive_y": True,
+        "preserve_dtype": [np.int32, np.int64],
+        "requires_y": True,
+    }
+    clf = BayesBase(None, True)
+    computed = clf._more_tags()
+    for key, value in expected.items():
+        assert key in computed
+        assert computed[key] == value
 
 
 # @pytest.mark.parametrize("estimators", [TAN(), KDB(k=2), AODE()])
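Finally, a hedged usage sketch of the refactored AODE, mirroring the iris setup used in the tests above. The astype(np.int32) cast is a crude stand-in for real discretization, since the estimator tags declare that features must be positive integers:

# Hypothetical driver code, not part of the commit.
import numpy as np
from sklearn.datasets import load_iris
from bayesclass.clfs import AODE

dataset = load_iris(as_frame=True)
X = dataset["data"].to_numpy().astype(np.int32)  # stand-in for discretization
y = dataset["target"].to_numpy()

clf = AODE(random_state=0)
clf.fit(X, y, features=dataset["feature_names"], class_name="class")
print(clf.nodes_edges())  # sums over the fitted SPODE ensemble members
y_pred = clf.predict(X)   # per-member votes combined via scipy.stats.mode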