Add version and example

This commit is contained in:
2022-11-13 14:24:06 +01:00
parent f9f91c54a7
commit 6070c2159a
6 changed files with 52 additions and 19 deletions

View File

@@ -1,5 +1,9 @@
from .bayesclass import TAN
from ._version import __version__
# Package-level metadata (authorship, copyright and licence strings).
__author__ = "Ricardo Montañana Gómez"
__copyright__ = "Copyright 2020-2023, Ricardo Montañana Gómez"
__license__ = "MIT License"
__author_email__ = "ricardo.montanana@alu.uclm.es"
# Names exported by a star-import of this package: the TAN class and the
# version string imported above.
__all__ = ["TAN", "__version__"]

View File

@@ -11,6 +11,7 @@ import networkx as nx
from pgmpy.estimators import TreeSearch, BayesianEstimator
from pgmpy.models import BayesianNetwork
import matplotlib.pyplot as plt
from ._version import __version__
class TAN(ClassifierMixin, BaseEstimator):
@@ -40,6 +41,21 @@ class TAN(ClassifierMixin, BaseEstimator):
self.show_progress = show_progress
self.random_state = random_state
def _more_tags(self):
import numpy as np
return {
"requires_positive_X": True,
"requires_positive_y": True,
"preserve_dtype": [np.int64, np.int32],
"requires_y": True,
}
@staticmethod
def version() -> str:
    """Return the version of the package.

    Returns
    -------
    str
        The ``__version__`` value imported from the package's
        ``_version`` module.
    """
    return __version__
def __check_params_fit(self, X, y, kwargs):
# Check that X and y have correct shape
X, y = check_X_y(X, y)
@@ -48,7 +64,7 @@ class TAN(ClassifierMixin, BaseEstimator):
# Default values
self.class_name_ = "class"
self.features_ = [f"feature_{i}" for i in range(X.shape[1])]
self.head_ = None
self.head_ = 0
expected_args = ["class_name", "features", "head"]
for key, value in kwargs.items():
if key in expected_args:
@@ -65,6 +81,7 @@ class TAN(ClassifierMixin, BaseEstimator):
)
if self.head_ is not None and self.head_ >= len(self.features_):
raise ValueError("Head index out of range")
return X, y
def fit(self, X, y, **kwargs):
"""A reference implementation of a fitting function for a classifier.
@@ -104,13 +121,11 @@ class TAN(ClassifierMixin, BaseEstimator):
>>> model.fit(train_data, train_y, features=features, class_name='E')
TAN(random_state=17)
"""
self.__check_params_fit(X, y, kwargs)
X_, y_ = self.__check_params_fit(X, y, kwargs)
# Store the information needed to build the model
self.X_ = X
self.y_ = y.astype(int)
self.dataset_ = pd.DataFrame(
self.X_, columns=self.features_, dtype="int16"
)
self.X_ = X_
self.y_ = y_
self.dataset_ = pd.DataFrame(self.X_, columns=self.features_)
self.dataset_[self.class_name_] = self.y_
# Build the DAG
self.__build()
@@ -136,7 +151,7 @@ class TAN(ClassifierMixin, BaseEstimator):
List
List of edges
"""
head = 0 if self.head_ is None else self.head_
head = self.head_
if self.simple_init:
first_node = self.features_[head]
return [
@@ -158,15 +173,12 @@ class TAN(ClassifierMixin, BaseEstimator):
# initialize a complete network with all edges
self.model_.add_edges_from(self.__initial_edges())
# learn graph structure
root_node = None if self.head_ is None else self.features_[self.head_]
est = TreeSearch(self.dataset_, root_node=root_node)
est = TreeSearch(self.dataset_, root_node=self.features_[self.head_])
self.dag_ = est.estimate(
estimator_type="tan",
class_node=self.class_name_,
show_progress=self.show_progress,
)
if self.head_ is None:
self.head_ = est.root_node
def __train(self):
self.model_ = BayesianNetwork(
@@ -174,11 +186,13 @@ class TAN(ClassifierMixin, BaseEstimator):
)
self.model_.fit(
self.dataset_,
# estimator=MaximumLikelihoodEstimator,
estimator=BayesianEstimator,
prior_type="K2",
)
def nodes_leaves(self):
    """Return a constant ``(0, 0)`` pair.

    NOTE(review): the name suggests (number of nodes, number of leaves),
    kept as a fixed placeholder for callers that expect such a method --
    confirm against the callers.

    Returns
    -------
    tuple
        Always ``(0, 0)``.
    """
    return 0, 0
def plot(self, title=""):
nx.draw_circular(
self.model_,

View File

@@ -7,6 +7,7 @@ from matplotlib.testing.conftest import mpl_test_settings
from bayesclass import TAN
from .._version import __version__
@pytest.fixture
@@ -16,7 +17,7 @@ def data():
return enc.fit_transform(X), y
def test_TAN_constructor():
def test_TAN_default_hyperparameters(data):
clf = TAN()
# Test default values of hyperparameters
assert clf.simple_init
@@ -26,6 +27,21 @@ def test_TAN_constructor():
assert clf.simple_init
assert clf.show_progress
assert clf.random_state == 17
clf.fit(*data)
assert clf.head_ == 0
assert clf.class_name_ == "class"
assert clf.features_ == [
"feature_0",
"feature_1",
"feature_2",
"feature_3",
]
def test_TAN_version():
    """Check that ``TAN.version()`` reports the package ``__version__``."""
    clf = TAN()
    assert __version__ == clf.version()
def test_TAN_random_head(data):

View File

@@ -7,5 +7,4 @@ from bayesclass import TAN
@pytest.mark.parametrize("estimator", [TAN()])
def test_all_estimators(estimator):
    """Run ``check_estimator`` against the parametrized estimator.

    The call raises on the first failed check, which is what makes the
    test fail; nothing is returned because pytest warns when a test
    function returns a non-None value.
    """
    check_estimator(estimator)

View File

@@ -1,7 +1,7 @@
import sys
import time
from sklearn.model_selection import cross_val_score, StratifiedKFold
from benchmark import Discretizer
from benchmark import Datasets
from bayesclass import TAN
import warnings
@@ -15,7 +15,7 @@ start = time.time()
random_state = 17
name = sys.argv[1]
n_folds = int(sys.argv[2]) if len(sys.argv) == 3 else 5
dt = Discretizer()
dt = Datasets()
name_list = list(dt) if name == "all" else [name]
print(f"Accuracy in {n_folds} folds stratified crossvalidation")
for name in name_list:

Binary file not shown.

Before

Width:  |  Height:  |  Size: 45 KiB

After

Width:  |  Height:  |  Size: 45 KiB