Add version and example

This commit is contained in:
2022-11-13 14:24:06 +01:00
parent f9f91c54a7
commit 6070c2159a
6 changed files with 52 additions and 19 deletions

View File

@@ -1,5 +1,9 @@
from .bayesclass import TAN from .bayesclass import TAN
from ._version import __version__ from ._version import __version__
__author__ = "Ricardo Montañana Gómez"
__copyright__ = "Copyright 2020-2023, Ricardo Montañana Gómez"
__license__ = "MIT License"
__author_email__ = "ricardo.montanana@alu.uclm.es"
__all__ = ["TAN", "__version__"] __all__ = ["TAN", "__version__"]

View File

@@ -11,6 +11,7 @@ import networkx as nx
from pgmpy.estimators import TreeSearch, BayesianEstimator from pgmpy.estimators import TreeSearch, BayesianEstimator
from pgmpy.models import BayesianNetwork from pgmpy.models import BayesianNetwork
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from ._version import __version__
class TAN(ClassifierMixin, BaseEstimator): class TAN(ClassifierMixin, BaseEstimator):
@@ -40,6 +41,21 @@ class TAN(ClassifierMixin, BaseEstimator):
self.show_progress = show_progress self.show_progress = show_progress
self.random_state = random_state self.random_state = random_state
def _more_tags(self):
    """Return the extra estimator tags declared by this classifier.

    The tags state that both X and y must be positive, that y is
    required, and that the ``int64`` and ``int32`` dtypes are preserved.

    NOTE(review): presumably consumed by scikit-learn's estimator tag
    machinery -- confirm against the scikit-learn version in use.

    Returns
    -------
    dict
        Mapping of tag name to tag value.
    """
    # Imported locally, as in the original implementation.
    import numpy as np

    preserved_dtypes = [np.int64, np.int32]
    tags = dict(
        requires_positive_X=True,
        requires_positive_y=True,
        preserve_dtype=preserved_dtypes,
        requires_y=True,
    )
    return tags
@staticmethod
def version() -> str:
    """Report the version of the package.

    Returns
    -------
    str
        The value of the module-level ``__version__``.
    """
    package_version = __version__
    return package_version
def __check_params_fit(self, X, y, kwargs): def __check_params_fit(self, X, y, kwargs):
# Check that X and y have correct shape # Check that X and y have correct shape
X, y = check_X_y(X, y) X, y = check_X_y(X, y)
@@ -48,7 +64,7 @@ class TAN(ClassifierMixin, BaseEstimator):
# Default values # Default values
self.class_name_ = "class" self.class_name_ = "class"
self.features_ = [f"feature_{i}" for i in range(X.shape[1])] self.features_ = [f"feature_{i}" for i in range(X.shape[1])]
self.head_ = None self.head_ = 0
expected_args = ["class_name", "features", "head"] expected_args = ["class_name", "features", "head"]
for key, value in kwargs.items(): for key, value in kwargs.items():
if key in expected_args: if key in expected_args:
@@ -65,6 +81,7 @@ class TAN(ClassifierMixin, BaseEstimator):
) )
if self.head_ is not None and self.head_ >= len(self.features_): if self.head_ is not None and self.head_ >= len(self.features_):
raise ValueError("Head index out of range") raise ValueError("Head index out of range")
return X, y
def fit(self, X, y, **kwargs): def fit(self, X, y, **kwargs):
"""A reference implementation of a fitting function for a classifier. """A reference implementation of a fitting function for a classifier.
@@ -104,13 +121,11 @@ class TAN(ClassifierMixin, BaseEstimator):
>>> model.fit(train_data, train_y, features=features, class_name='E') >>> model.fit(train_data, train_y, features=features, class_name='E')
TAN(random_state=17) TAN(random_state=17)
""" """
self.__check_params_fit(X, y, kwargs) X_, y_ = self.__check_params_fit(X, y, kwargs)
# Store the information needed to build the model # Store the information needed to build the model
self.X_ = X self.X_ = X_
self.y_ = y.astype(int) self.y_ = y_
self.dataset_ = pd.DataFrame( self.dataset_ = pd.DataFrame(self.X_, columns=self.features_)
self.X_, columns=self.features_, dtype="int16"
)
self.dataset_[self.class_name_] = self.y_ self.dataset_[self.class_name_] = self.y_
# Build the DAG # Build the DAG
self.__build() self.__build()
@@ -136,7 +151,7 @@ class TAN(ClassifierMixin, BaseEstimator):
List List
List of edges List of edges
""" """
head = 0 if self.head_ is None else self.head_ head = self.head_
if self.simple_init: if self.simple_init:
first_node = self.features_[head] first_node = self.features_[head]
return [ return [
@@ -158,15 +173,12 @@ class TAN(ClassifierMixin, BaseEstimator):
# initialize a complete network with all edges # initialize a complete network with all edges
self.model_.add_edges_from(self.__initial_edges()) self.model_.add_edges_from(self.__initial_edges())
# learn graph structure # learn graph structure
root_node = None if self.head_ is None else self.features_[self.head_] est = TreeSearch(self.dataset_, root_node=self.features_[self.head_])
est = TreeSearch(self.dataset_, root_node=root_node)
self.dag_ = est.estimate( self.dag_ = est.estimate(
estimator_type="tan", estimator_type="tan",
class_node=self.class_name_, class_node=self.class_name_,
show_progress=self.show_progress, show_progress=self.show_progress,
) )
if self.head_ is None:
self.head_ = est.root_node
def __train(self): def __train(self):
self.model_ = BayesianNetwork( self.model_ = BayesianNetwork(
@@ -174,11 +186,13 @@ class TAN(ClassifierMixin, BaseEstimator):
) )
self.model_.fit( self.model_.fit(
self.dataset_, self.dataset_,
# estimator=MaximumLikelihoodEstimator,
estimator=BayesianEstimator, estimator=BayesianEstimator,
prior_type="K2", prior_type="K2",
) )
def nodes_leaves(self):
    """Return a fixed ``(0, 0)`` pair.

    NOTE(review): presumably a placeholder so that callers expecting a
    (nodes, leaves) count from a classifier keep working -- confirm
    against the callers.

    Returns
    -------
    tuple
        Always ``(0, 0)``.
    """
    nodes, leaves = 0, 0
    return nodes, leaves
def plot(self, title=""): def plot(self, title=""):
nx.draw_circular( nx.draw_circular(
self.model_, self.model_,

View File

@@ -7,6 +7,7 @@ from matplotlib.testing.conftest import mpl_test_settings
from bayesclass import TAN from bayesclass import TAN
from .._version import __version__
@pytest.fixture @pytest.fixture
@@ -16,7 +17,7 @@ def data():
return enc.fit_transform(X), y return enc.fit_transform(X), y
def test_TAN_constructor(): def test_TAN_default_hyperparameters(data):
clf = TAN() clf = TAN()
# Test default values of hyperparameters # Test default values of hyperparameters
assert clf.simple_init assert clf.simple_init
@@ -26,6 +27,21 @@ def test_TAN_constructor():
assert clf.simple_init assert clf.simple_init
assert clf.show_progress assert clf.show_progress
assert clf.random_state == 17 assert clf.random_state == 17
clf.fit(*data)
assert clf.head_ == 0
assert clf.class_name_ == "class"
assert clf.features_ == [
"feature_0",
"feature_1",
"feature_2",
"feature_3",
]
def test_TAN_version():
    """Check that a TAN instance reports the package version."""
    reported = TAN().version()
    assert reported == __version__
def test_TAN_random_head(data): def test_TAN_random_head(data):

View File

@@ -7,5 +7,4 @@ from bayesclass import TAN
@pytest.mark.parametrize("estimator", [TAN()]) @pytest.mark.parametrize("estimator", [TAN()])
def test_all_estimators(estimator): def test_all_estimators(estimator):
# return check_estimator(estimator) return check_estimator(estimator)
assert True

View File

@@ -1,7 +1,7 @@
import sys import sys
import time import time
from sklearn.model_selection import cross_val_score, StratifiedKFold from sklearn.model_selection import cross_val_score, StratifiedKFold
from benchmark import Discretizer from benchmark import Datasets
from bayesclass import TAN from bayesclass import TAN
import warnings import warnings
@@ -15,7 +15,7 @@ start = time.time()
random_state = 17 random_state = 17
name = sys.argv[1] name = sys.argv[1]
n_folds = int(sys.argv[2]) if len(sys.argv) == 3 else 5 n_folds = int(sys.argv[2]) if len(sys.argv) == 3 else 5
dt = Discretizer() dt = Datasets()
name_list = list(dt) if name == "all" else [name] name_list = list(dt) if name == "all" else [name]
print(f"Accuracy in {n_folds} folds stratified crossvalidation") print(f"Accuracy in {n_folds} folds stratified crossvalidation")
for name in name_list: for name in name_list:

Binary file not shown.

Before

Width:  |  Height:  |  Size: 45 KiB

After

Width:  |  Height:  |  Size: 45 KiB