mirror of
https://github.com/Doctorado-ML/bayesclass.git
synced 2025-08-15 23:55:57 +00:00
Add version and example
This commit is contained in:
@@ -1,5 +1,9 @@
|
||||
from .bayesclass import TAN
|
||||
|
||||
from ._version import __version__
|
||||
|
||||
__author__ = "Ricardo Montañana Gómez"
|
||||
__copyright__ = "Copyright 2020-2023, Ricardo Montañana Gómez"
|
||||
__license__ = "MIT License"
|
||||
__author_email__ = "ricardo.montanana@alu.uclm.es"
|
||||
|
||||
__all__ = ["TAN", "__version__"]
|
||||
|
@@ -11,6 +11,7 @@ import networkx as nx
|
||||
from pgmpy.estimators import TreeSearch, BayesianEstimator
|
||||
from pgmpy.models import BayesianNetwork
|
||||
import matplotlib.pyplot as plt
|
||||
from ._version import __version__
|
||||
|
||||
|
||||
class TAN(ClassifierMixin, BaseEstimator):
|
||||
@@ -40,6 +41,21 @@ class TAN(ClassifierMixin, BaseEstimator):
|
||||
self.show_progress = show_progress
|
||||
self.random_state = random_state
|
||||
|
||||
def _more_tags(self):
|
||||
import numpy as np
|
||||
|
||||
return {
|
||||
"requires_positive_X": True,
|
||||
"requires_positive_y": True,
|
||||
"preserve_dtype": [np.int64, np.int32],
|
||||
"requires_y": True,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def version() -> str:
|
||||
"""Return the version of the package."""
|
||||
return __version__
|
||||
|
||||
def __check_params_fit(self, X, y, kwargs):
|
||||
# Check that X and y have correct shape
|
||||
X, y = check_X_y(X, y)
|
||||
@@ -48,7 +64,7 @@ class TAN(ClassifierMixin, BaseEstimator):
|
||||
# Default values
|
||||
self.class_name_ = "class"
|
||||
self.features_ = [f"feature_{i}" for i in range(X.shape[1])]
|
||||
self.head_ = None
|
||||
self.head_ = 0
|
||||
expected_args = ["class_name", "features", "head"]
|
||||
for key, value in kwargs.items():
|
||||
if key in expected_args:
|
||||
@@ -65,6 +81,7 @@ class TAN(ClassifierMixin, BaseEstimator):
|
||||
)
|
||||
if self.head_ is not None and self.head_ >= len(self.features_):
|
||||
raise ValueError("Head index out of range")
|
||||
return X, y
|
||||
|
||||
def fit(self, X, y, **kwargs):
|
||||
"""A reference implementation of a fitting function for a classifier.
|
||||
@@ -104,13 +121,11 @@ class TAN(ClassifierMixin, BaseEstimator):
|
||||
>>> model.fit(train_data, train_y, features=features, class_name='E')
|
||||
TAN(random_state=17)
|
||||
"""
|
||||
self.__check_params_fit(X, y, kwargs)
|
||||
X_, y_ = self.__check_params_fit(X, y, kwargs)
|
||||
# Store the information needed to build the model
|
||||
self.X_ = X
|
||||
self.y_ = y.astype(int)
|
||||
self.dataset_ = pd.DataFrame(
|
||||
self.X_, columns=self.features_, dtype="int16"
|
||||
)
|
||||
self.X_ = X_
|
||||
self.y_ = y_
|
||||
self.dataset_ = pd.DataFrame(self.X_, columns=self.features_)
|
||||
self.dataset_[self.class_name_] = self.y_
|
||||
# Build the DAG
|
||||
self.__build()
|
||||
@@ -136,7 +151,7 @@ class TAN(ClassifierMixin, BaseEstimator):
|
||||
List
|
||||
List of edges
|
||||
"""
|
||||
head = 0 if self.head_ is None else self.head_
|
||||
head = self.head_
|
||||
if self.simple_init:
|
||||
first_node = self.features_[head]
|
||||
return [
|
||||
@@ -158,15 +173,12 @@ class TAN(ClassifierMixin, BaseEstimator):
|
||||
# initialize a complete network with all edges
|
||||
self.model_.add_edges_from(self.__initial_edges())
|
||||
# learn graph structure
|
||||
root_node = None if self.head_ is None else self.features_[self.head_]
|
||||
est = TreeSearch(self.dataset_, root_node=root_node)
|
||||
est = TreeSearch(self.dataset_, root_node=self.features_[self.head_])
|
||||
self.dag_ = est.estimate(
|
||||
estimator_type="tan",
|
||||
class_node=self.class_name_,
|
||||
show_progress=self.show_progress,
|
||||
)
|
||||
if self.head_ is None:
|
||||
self.head_ = est.root_node
|
||||
|
||||
def __train(self):
|
||||
self.model_ = BayesianNetwork(
|
||||
@@ -174,11 +186,13 @@ class TAN(ClassifierMixin, BaseEstimator):
|
||||
)
|
||||
self.model_.fit(
|
||||
self.dataset_,
|
||||
# estimator=MaximumLikelihoodEstimator,
|
||||
estimator=BayesianEstimator,
|
||||
prior_type="K2",
|
||||
)
|
||||
|
||||
def nodes_leaves(self):
|
||||
return 0, 0
|
||||
|
||||
def plot(self, title=""):
|
||||
nx.draw_circular(
|
||||
self.model_,
|
||||
|
@@ -7,6 +7,7 @@ from matplotlib.testing.conftest import mpl_test_settings
|
||||
|
||||
|
||||
from bayesclass import TAN
|
||||
from .._version import __version__
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -16,7 +17,7 @@ def data():
|
||||
return enc.fit_transform(X), y
|
||||
|
||||
|
||||
def test_TAN_constructor():
|
||||
def test_TAN_default_hyperparameters(data):
|
||||
clf = TAN()
|
||||
# Test default values of hyperparameters
|
||||
assert clf.simple_init
|
||||
@@ -26,6 +27,21 @@ def test_TAN_constructor():
|
||||
assert clf.simple_init
|
||||
assert clf.show_progress
|
||||
assert clf.random_state == 17
|
||||
clf.fit(*data)
|
||||
assert clf.head_ == 0
|
||||
assert clf.class_name_ == "class"
|
||||
assert clf.features_ == [
|
||||
"feature_0",
|
||||
"feature_1",
|
||||
"feature_2",
|
||||
"feature_3",
|
||||
]
|
||||
|
||||
|
||||
def test_TAN_version():
|
||||
"""Check TAN version."""
|
||||
clf = TAN()
|
||||
assert __version__ == clf.version()
|
||||
|
||||
|
||||
def test_TAN_random_head(data):
|
||||
|
@@ -7,5 +7,4 @@ from bayesclass import TAN
|
||||
|
||||
@pytest.mark.parametrize("estimator", [TAN()])
|
||||
def test_all_estimators(estimator):
|
||||
# return check_estimator(estimator)
|
||||
assert True
|
||||
return check_estimator(estimator)
|
||||
|
@@ -1,7 +1,7 @@
|
||||
import sys
|
||||
import time
|
||||
from sklearn.model_selection import cross_val_score, StratifiedKFold
|
||||
from benchmark import Discretizer
|
||||
from benchmark import Datasets
|
||||
from bayesclass import TAN
|
||||
import warnings
|
||||
|
||||
@@ -15,7 +15,7 @@ start = time.time()
|
||||
random_state = 17
|
||||
name = sys.argv[1]
|
||||
n_folds = int(sys.argv[2]) if len(sys.argv) == 3 else 5
|
||||
dt = Discretizer()
|
||||
dt = Datasets()
|
||||
name_list = list(dt) if name == "all" else [name]
|
||||
print(f"Accuracy in {n_folds} folds stratified crossvalidation")
|
||||
for name in name_list:
|
||||
|
Binary file not shown.
Before Width: | Height: | Size: 45 KiB After Width: | Height: | Size: 45 KiB |
Reference in New Issue
Block a user