diff --git a/bayesclass/__init__.py b/bayesclass/__init__.py index aadf6fe..cc119e8 100644 --- a/bayesclass/__init__.py +++ b/bayesclass/__init__.py @@ -1,5 +1,9 @@ from .bayesclass import TAN - from ._version import __version__ +__author__ = "Ricardo Montañana Gómez" +__copyright__ = "Copyright 2020-2023, Ricardo Montañana Gómez" +__license__ = "MIT License" +__author_email__ = "ricardo.montanana@alu.uclm.es" + __all__ = ["TAN", "__version__"] diff --git a/bayesclass/bayesclass.py b/bayesclass/bayesclass.py index e2c9fba..ff2d6b1 100644 --- a/bayesclass/bayesclass.py +++ b/bayesclass/bayesclass.py @@ -11,6 +11,7 @@ import networkx as nx from pgmpy.estimators import TreeSearch, BayesianEstimator from pgmpy.models import BayesianNetwork import matplotlib.pyplot as plt +from ._version import __version__ class TAN(ClassifierMixin, BaseEstimator): @@ -40,6 +41,21 @@ class TAN(ClassifierMixin, BaseEstimator): self.show_progress = show_progress self.random_state = random_state + def _more_tags(self): + import numpy as np + + return { + "requires_positive_X": True, + "requires_positive_y": True, + "preserve_dtype": [np.int64, np.int32], + "requires_y": True, + } + + @staticmethod + def version() -> str: + """Return the version of the package.""" + return __version__ + def __check_params_fit(self, X, y, kwargs): # Check that X and y have correct shape X, y = check_X_y(X, y) @@ -48,7 +64,7 @@ class TAN(ClassifierMixin, BaseEstimator): # Default values self.class_name_ = "class" self.features_ = [f"feature_{i}" for i in range(X.shape[1])] - self.head_ = None + self.head_ = 0 expected_args = ["class_name", "features", "head"] for key, value in kwargs.items(): if key in expected_args: @@ -65,6 +81,7 @@ class TAN(ClassifierMixin, BaseEstimator): ) if self.head_ is not None and self.head_ >= len(self.features_): raise ValueError("Head index out of range") + return X, y def fit(self, X, y, **kwargs): """A reference implementation of a fitting function for a classifier. @@ -104,13 +121,11 @@ class TAN(ClassifierMixin, BaseEstimator): >>> model.fit(train_data, train_y, features=features, class_name='E') TAN(random_state=17) """ - self.__check_params_fit(X, y, kwargs) + X_, y_ = self.__check_params_fit(X, y, kwargs) # Store the information needed to build the model - self.X_ = X - self.y_ = y.astype(int) - self.dataset_ = pd.DataFrame( - self.X_, columns=self.features_, dtype="int16" - ) + self.X_ = X_ + self.y_ = y_ + self.dataset_ = pd.DataFrame(self.X_, columns=self.features_) self.dataset_[self.class_name_] = self.y_ # Build the DAG self.__build() @@ -136,7 +151,7 @@ class TAN(ClassifierMixin, BaseEstimator): List List of edges """ - head = 0 if self.head_ is None else self.head_ + head = self.head_ if self.simple_init: first_node = self.features_[head] return [ @@ -158,15 +173,12 @@ class TAN(ClassifierMixin, BaseEstimator): # initialize a complete network with all edges self.model_.add_edges_from(self.__initial_edges()) # learn graph structure - root_node = None if self.head_ is None else self.features_[self.head_] - est = TreeSearch(self.dataset_, root_node=root_node) + est = TreeSearch(self.dataset_, root_node=self.features_[self.head_]) self.dag_ = est.estimate( estimator_type="tan", class_node=self.class_name_, show_progress=self.show_progress, ) - if self.head_ is None: - self.head_ = est.root_node def __train(self): self.model_ = BayesianNetwork( @@ -174,11 +186,13 @@ class TAN(ClassifierMixin, BaseEstimator): ) self.model_.fit( self.dataset_, - # estimator=MaximumLikelihoodEstimator, estimator=BayesianEstimator, prior_type="K2", ) + def nodes_leaves(self): + return 0, 0 + def plot(self, title=""): nx.draw_circular( self.model_, diff --git a/bayesclass/tests/test_bayesclass.py b/bayesclass/tests/test_bayesclass.py index 860c1cd..62d807b 100644 --- a/bayesclass/tests/test_bayesclass.py +++ b/bayesclass/tests/test_bayesclass.py @@ -7,6 +7,7 @@ from matplotlib.testing.conftest import mpl_test_settings from bayesclass import TAN +from .._version import __version__ @pytest.fixture @@ -16,7 +17,7 @@ def data(): return enc.fit_transform(X), y -def test_TAN_constructor(): +def test_TAN_default_hyperparameters(data): clf = TAN() # Test default values of hyperparameters assert clf.simple_init @@ -26,6 +27,21 @@ def test_TAN_constructor(): assert clf.simple_init assert clf.show_progress assert clf.random_state == 17 + clf.fit(*data) + assert clf.head_ == 0 + assert clf.class_name_ == "class" + assert clf.features_ == [ + "feature_0", + "feature_1", + "feature_2", + "feature_3", + ] + + +def test_TAN_version(): + """Check TAN version.""" + clf = TAN() + assert __version__ == clf.version() def test_TAN_random_head(data): diff --git a/bayesclass/tests/test_common.py b/bayesclass/tests/test_common.py index 9cc866e..003dc14 100644 --- a/bayesclass/tests/test_common.py +++ b/bayesclass/tests/test_common.py @@ -7,5 +7,4 @@ from bayesclass import TAN @pytest.mark.parametrize("estimator", [TAN()]) def test_all_estimators(estimator): - # return check_estimator(estimator) - assert True + return check_estimator(estimator) diff --git a/example.py b/example.py index e9e1158..5f4e164 100644 --- a/example.py +++ b/example.py @@ -1,7 +1,7 @@ import sys import time from sklearn.model_selection import cross_val_score, StratifiedKFold -from benchmark import Discretizer +from benchmark import Datasets from bayesclass import TAN import warnings @@ -15,7 +15,7 @@ start = time.time() random_state = 17 name = sys.argv[1] n_folds = int(sys.argv[2]) if len(sys.argv) == 3 else 5 -dt = Discretizer() +dt = Datasets() name_list = list(dt) if name == "all" else [name] print(f"Accuracy in {n_folds} folds stratified crossvalidation") for name in name_list: diff --git a/result_images/test_bayesclass/line_dashes.png b/result_images/test_bayesclass/line_dashes.png index 9c989ec..4322d0e 100644 Binary files a/result_images/test_bayesclass/line_dashes.png and b/result_images/test_bayesclass/line_dashes.png differ