mirror of
https://github.com/Doctorado-ML/bayesclass.git
synced 2025-08-17 00:26:10 +00:00
Add version and example
This commit is contained in:
@@ -1,5 +1,9 @@
|
|||||||
from .bayesclass import TAN
|
from .bayesclass import TAN
|
||||||
|
|
||||||
from ._version import __version__
|
from ._version import __version__
|
||||||
|
|
||||||
|
__author__ = "Ricardo Montañana Gómez"
|
||||||
|
__copyright__ = "Copyright 2020-2023, Ricardo Montañana Gómez"
|
||||||
|
__license__ = "MIT License"
|
||||||
|
__author_email__ = "ricardo.montanana@alu.uclm.es"
|
||||||
|
|
||||||
__all__ = ["TAN", "__version__"]
|
__all__ = ["TAN", "__version__"]
|
||||||
|
@@ -11,6 +11,7 @@ import networkx as nx
|
|||||||
from pgmpy.estimators import TreeSearch, BayesianEstimator
|
from pgmpy.estimators import TreeSearch, BayesianEstimator
|
||||||
from pgmpy.models import BayesianNetwork
|
from pgmpy.models import BayesianNetwork
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
from ._version import __version__
|
||||||
|
|
||||||
|
|
||||||
class TAN(ClassifierMixin, BaseEstimator):
|
class TAN(ClassifierMixin, BaseEstimator):
|
||||||
@@ -40,6 +41,21 @@ class TAN(ClassifierMixin, BaseEstimator):
|
|||||||
self.show_progress = show_progress
|
self.show_progress = show_progress
|
||||||
self.random_state = random_state
|
self.random_state = random_state
|
||||||
|
|
||||||
|
def _more_tags(self):
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
return {
|
||||||
|
"requires_positive_X": True,
|
||||||
|
"requires_positive_y": True,
|
||||||
|
"preserve_dtype": [np.int64, np.int32],
|
||||||
|
"requires_y": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def version() -> str:
|
||||||
|
"""Return the version of the package."""
|
||||||
|
return __version__
|
||||||
|
|
||||||
def __check_params_fit(self, X, y, kwargs):
|
def __check_params_fit(self, X, y, kwargs):
|
||||||
# Check that X and y have correct shape
|
# Check that X and y have correct shape
|
||||||
X, y = check_X_y(X, y)
|
X, y = check_X_y(X, y)
|
||||||
@@ -48,7 +64,7 @@ class TAN(ClassifierMixin, BaseEstimator):
|
|||||||
# Default values
|
# Default values
|
||||||
self.class_name_ = "class"
|
self.class_name_ = "class"
|
||||||
self.features_ = [f"feature_{i}" for i in range(X.shape[1])]
|
self.features_ = [f"feature_{i}" for i in range(X.shape[1])]
|
||||||
self.head_ = None
|
self.head_ = 0
|
||||||
expected_args = ["class_name", "features", "head"]
|
expected_args = ["class_name", "features", "head"]
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
if key in expected_args:
|
if key in expected_args:
|
||||||
@@ -65,6 +81,7 @@ class TAN(ClassifierMixin, BaseEstimator):
|
|||||||
)
|
)
|
||||||
if self.head_ is not None and self.head_ >= len(self.features_):
|
if self.head_ is not None and self.head_ >= len(self.features_):
|
||||||
raise ValueError("Head index out of range")
|
raise ValueError("Head index out of range")
|
||||||
|
return X, y
|
||||||
|
|
||||||
def fit(self, X, y, **kwargs):
|
def fit(self, X, y, **kwargs):
|
||||||
"""A reference implementation of a fitting function for a classifier.
|
"""A reference implementation of a fitting function for a classifier.
|
||||||
@@ -104,13 +121,11 @@ class TAN(ClassifierMixin, BaseEstimator):
|
|||||||
>>> model.fit(train_data, train_y, features=features, class_name='E')
|
>>> model.fit(train_data, train_y, features=features, class_name='E')
|
||||||
TAN(random_state=17)
|
TAN(random_state=17)
|
||||||
"""
|
"""
|
||||||
self.__check_params_fit(X, y, kwargs)
|
X_, y_ = self.__check_params_fit(X, y, kwargs)
|
||||||
# Store the information needed to build the model
|
# Store the information needed to build the model
|
||||||
self.X_ = X
|
self.X_ = X_
|
||||||
self.y_ = y.astype(int)
|
self.y_ = y_
|
||||||
self.dataset_ = pd.DataFrame(
|
self.dataset_ = pd.DataFrame(self.X_, columns=self.features_)
|
||||||
self.X_, columns=self.features_, dtype="int16"
|
|
||||||
)
|
|
||||||
self.dataset_[self.class_name_] = self.y_
|
self.dataset_[self.class_name_] = self.y_
|
||||||
# Build the DAG
|
# Build the DAG
|
||||||
self.__build()
|
self.__build()
|
||||||
@@ -136,7 +151,7 @@ class TAN(ClassifierMixin, BaseEstimator):
|
|||||||
List
|
List
|
||||||
List of edges
|
List of edges
|
||||||
"""
|
"""
|
||||||
head = 0 if self.head_ is None else self.head_
|
head = self.head_
|
||||||
if self.simple_init:
|
if self.simple_init:
|
||||||
first_node = self.features_[head]
|
first_node = self.features_[head]
|
||||||
return [
|
return [
|
||||||
@@ -158,15 +173,12 @@ class TAN(ClassifierMixin, BaseEstimator):
|
|||||||
# initialize a complete network with all edges
|
# initialize a complete network with all edges
|
||||||
self.model_.add_edges_from(self.__initial_edges())
|
self.model_.add_edges_from(self.__initial_edges())
|
||||||
# learn graph structure
|
# learn graph structure
|
||||||
root_node = None if self.head_ is None else self.features_[self.head_]
|
est = TreeSearch(self.dataset_, root_node=self.features_[self.head_])
|
||||||
est = TreeSearch(self.dataset_, root_node=root_node)
|
|
||||||
self.dag_ = est.estimate(
|
self.dag_ = est.estimate(
|
||||||
estimator_type="tan",
|
estimator_type="tan",
|
||||||
class_node=self.class_name_,
|
class_node=self.class_name_,
|
||||||
show_progress=self.show_progress,
|
show_progress=self.show_progress,
|
||||||
)
|
)
|
||||||
if self.head_ is None:
|
|
||||||
self.head_ = est.root_node
|
|
||||||
|
|
||||||
def __train(self):
|
def __train(self):
|
||||||
self.model_ = BayesianNetwork(
|
self.model_ = BayesianNetwork(
|
||||||
@@ -174,11 +186,13 @@ class TAN(ClassifierMixin, BaseEstimator):
|
|||||||
)
|
)
|
||||||
self.model_.fit(
|
self.model_.fit(
|
||||||
self.dataset_,
|
self.dataset_,
|
||||||
# estimator=MaximumLikelihoodEstimator,
|
|
||||||
estimator=BayesianEstimator,
|
estimator=BayesianEstimator,
|
||||||
prior_type="K2",
|
prior_type="K2",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def nodes_leaves(self):
|
||||||
|
return 0, 0
|
||||||
|
|
||||||
def plot(self, title=""):
|
def plot(self, title=""):
|
||||||
nx.draw_circular(
|
nx.draw_circular(
|
||||||
self.model_,
|
self.model_,
|
||||||
|
@@ -7,6 +7,7 @@ from matplotlib.testing.conftest import mpl_test_settings
|
|||||||
|
|
||||||
|
|
||||||
from bayesclass import TAN
|
from bayesclass import TAN
|
||||||
|
from .._version import __version__
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -16,7 +17,7 @@ def data():
|
|||||||
return enc.fit_transform(X), y
|
return enc.fit_transform(X), y
|
||||||
|
|
||||||
|
|
||||||
def test_TAN_constructor():
|
def test_TAN_default_hyperparameters(data):
|
||||||
clf = TAN()
|
clf = TAN()
|
||||||
# Test default values of hyperparameters
|
# Test default values of hyperparameters
|
||||||
assert clf.simple_init
|
assert clf.simple_init
|
||||||
@@ -26,6 +27,21 @@ def test_TAN_constructor():
|
|||||||
assert clf.simple_init
|
assert clf.simple_init
|
||||||
assert clf.show_progress
|
assert clf.show_progress
|
||||||
assert clf.random_state == 17
|
assert clf.random_state == 17
|
||||||
|
clf.fit(*data)
|
||||||
|
assert clf.head_ == 0
|
||||||
|
assert clf.class_name_ == "class"
|
||||||
|
assert clf.features_ == [
|
||||||
|
"feature_0",
|
||||||
|
"feature_1",
|
||||||
|
"feature_2",
|
||||||
|
"feature_3",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_TAN_version():
|
||||||
|
"""Check TAN version."""
|
||||||
|
clf = TAN()
|
||||||
|
assert __version__ == clf.version()
|
||||||
|
|
||||||
|
|
||||||
def test_TAN_random_head(data):
|
def test_TAN_random_head(data):
|
||||||
|
@@ -7,5 +7,4 @@ from bayesclass import TAN
|
|||||||
|
|
||||||
@pytest.mark.parametrize("estimator", [TAN()])
|
@pytest.mark.parametrize("estimator", [TAN()])
|
||||||
def test_all_estimators(estimator):
|
def test_all_estimators(estimator):
|
||||||
# return check_estimator(estimator)
|
return check_estimator(estimator)
|
||||||
assert True
|
|
||||||
|
@@ -1,7 +1,7 @@
|
|||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from sklearn.model_selection import cross_val_score, StratifiedKFold
|
from sklearn.model_selection import cross_val_score, StratifiedKFold
|
||||||
from benchmark import Discretizer
|
from benchmark import Datasets
|
||||||
from bayesclass import TAN
|
from bayesclass import TAN
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
@@ -15,7 +15,7 @@ start = time.time()
|
|||||||
random_state = 17
|
random_state = 17
|
||||||
name = sys.argv[1]
|
name = sys.argv[1]
|
||||||
n_folds = int(sys.argv[2]) if len(sys.argv) == 3 else 5
|
n_folds = int(sys.argv[2]) if len(sys.argv) == 3 else 5
|
||||||
dt = Discretizer()
|
dt = Datasets()
|
||||||
name_list = list(dt) if name == "all" else [name]
|
name_list = list(dt) if name == "all" else [name]
|
||||||
print(f"Accuracy in {n_folds} folds stratified crossvalidation")
|
print(f"Accuracy in {n_folds} folds stratified crossvalidation")
|
||||||
for name in name_list:
|
for name in name_list:
|
||||||
|
Binary file not shown.
Before Width: | Height: | Size: 45 KiB After Width: | Height: | Size: 45 KiB |
Reference in New Issue
Block a user