mirror of
https://github.com/Doctorado-ML/bayesclass.git
synced 2025-08-20 18:15:57 +00:00
Add example of usage
This commit is contained in:
@@ -7,13 +7,8 @@ import pandas as pd
|
||||
from sklearn.base import ClassifierMixin, BaseEstimator
|
||||
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
|
||||
from sklearn.utils.multiclass import unique_labels
|
||||
from sklearn.exceptions import NotFittedError
|
||||
import networkx as nx
|
||||
from pgmpy.estimators import (
|
||||
TreeSearch,
|
||||
BayesianEstimator,
|
||||
# MaximumLikelihoodEstimator,
|
||||
)
|
||||
from pgmpy.estimators import TreeSearch, BayesianEstimator
|
||||
from pgmpy.models import BayesianNetwork
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
@@ -39,32 +34,13 @@ class TAN(ClassifierMixin, BaseEstimator):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, simple_init=False, show_progress=False, random_state=None
|
||||
self, simple_init=True, show_progress=False, random_state=None
|
||||
):
|
||||
self.simple_init = simple_init
|
||||
self.show_progress = show_progress
|
||||
self.random_state = random_state
|
||||
|
||||
def fit(self, X, y, **kwargs):
|
||||
"""A reference implementation of a fitting function for a classifier.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : array-like, shape (n_samples, n_features)
|
||||
The training input samples.
|
||||
y : array-like, shape (n_samples,)
|
||||
The target values. An array of int.
|
||||
**kwargs : dict
|
||||
class_name : str (default='class') Name of the class column
|
||||
features: list (default=None) List of features
|
||||
head: int (default=None) Index of the head node. Default value
|
||||
gets the node with the highest sum of weights (mutual_info)
|
||||
|
||||
Returns
|
||||
-------
|
||||
self : object
|
||||
Returns self.
|
||||
"""
|
||||
def __check_params_fit(self, X, y, kwargs):
|
||||
# Check that X and y have correct shape
|
||||
X, y = check_X_y(X, y)
|
||||
# Store the classes seen during fit
|
||||
@@ -90,16 +66,55 @@ class TAN(ClassifierMixin, BaseEstimator):
|
||||
if self.head_ is not None and self.head_ >= len(self.features_):
|
||||
raise ValueError("Head index out of range")
|
||||
|
||||
def fit(self, X, y, **kwargs):
|
||||
"""A reference implementation of a fitting function for a classifier.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
X : array-like, shape (n_samples, n_features)
|
||||
The training input samples.
|
||||
y : array-like, shape (n_samples,)
|
||||
The target values. An array of int.
|
||||
**kwargs : dict
|
||||
class_name : str (default='class') Name of the class column
|
||||
features: list (default=None) List of features
|
||||
head: int (default=None) Index of the head node. Default value
|
||||
gets the node with the highest sum of weights (mutual_info)
|
||||
|
||||
Returns
|
||||
-------
|
||||
self : object
|
||||
Returns self.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import numpy as np
|
||||
>>> import pandas as pd
|
||||
>>> from bayesclass import TAN
|
||||
>>> features = ['A', 'B', 'C', 'D', 'E']
|
||||
>>> np.random.seed(17)
|
||||
>>> values = pd.DataFrame(np.random.randint(low=0, high=2,
|
||||
... size=(1000, 5)), columns=features)
|
||||
>>> train_data = values[:800]
|
||||
>>> train_y = train_data['E']
|
||||
>>> predict_data = values[800:]
|
||||
>>> train_data = train_data.drop('E', axis=1)
|
||||
>>> model = TAN(random_state=17)
|
||||
>>> features.remove('E')
|
||||
>>> model.fit(train_data, train_y, features=features, class_name='E')
|
||||
TAN(random_state=17)
|
||||
"""
|
||||
self.__check_params_fit(X, y, kwargs)
|
||||
# Store the information needed to build the model
|
||||
self.X_ = X
|
||||
self.y_ = y.astype(int)
|
||||
self.dataset_ = pd.DataFrame(
|
||||
self.X_, columns=self.features_, dtype="int16"
|
||||
)
|
||||
self.dataset_[self.class_name_] = self.y_
|
||||
try:
|
||||
check_is_fitted(self, ["X_", "y_", "fitted_"])
|
||||
except NotFittedError:
|
||||
self.__build()
|
||||
# Build the DAG
|
||||
self.__build()
|
||||
# Train the model
|
||||
self.__train()
|
||||
self.fitted_ = True
|
||||
# Return the classifier
|
||||
@@ -145,24 +160,23 @@ class TAN(ClassifierMixin, BaseEstimator):
|
||||
# learn graph structure
|
||||
root_node = None if self.head_ is None else self.features_[self.head_]
|
||||
est = TreeSearch(self.dataset_, root_node=root_node)
|
||||
dag = est.estimate(
|
||||
self.dag_ = est.estimate(
|
||||
estimator_type="tan",
|
||||
class_node=self.class_name_,
|
||||
show_progress=self.show_progress,
|
||||
)
|
||||
if self.head_ is None:
|
||||
self.head_ = est.root_node
|
||||
self.model_ = BayesianNetwork(
|
||||
dag.edges(), show_progress=self.show_progress
|
||||
)
|
||||
|
||||
def __train(self):
|
||||
self.model_ = BayesianNetwork(
|
||||
self.dag_.edges(), show_progress=self.show_progress
|
||||
)
|
||||
self.model_.fit(
|
||||
self.dataset_,
|
||||
# estimator=MaximumLikelihoodEstimator,
|
||||
estimator=BayesianEstimator,
|
||||
prior_type="K2",
|
||||
n_jobs=1,
|
||||
)
|
||||
|
||||
def plot(self, title=""):
|
||||
@@ -203,7 +217,7 @@ class TAN(ClassifierMixin, BaseEstimator):
|
||||
>>> train_data = values[:800]
|
||||
>>> train_y = train_data['E']
|
||||
>>> predict_data = values[800:]
|
||||
>>> train_data.drop('E', axis=1, inplace=True)
|
||||
>>> train_data = train_data.drop('E', axis=1)
|
||||
>>> model = TAN(random_state=17)
|
||||
>>> features.remove('E')
|
||||
>>> model.fit(train_data, train_y, features=features, class_name='E')
|
||||
|
@@ -1 +0,0 @@
|
||||
m0 <- ulam(alist(height ~ dnorm(mu, sigma), mu <- a, a ~ dnorm(186, 10), sigma ~ dexp(1)), data = d, chains = 4, iter = 2000, cores = 4, log_lik=TRUE)
|
@@ -19,7 +19,7 @@ def data():
|
||||
def test_TAN_constructor():
|
||||
clf = TAN()
|
||||
# Test default values of hyperparameters
|
||||
assert not clf.simple_init
|
||||
assert clf.simple_init
|
||||
assert not clf.show_progress
|
||||
assert clf.random_state is None
|
||||
clf = TAN(simple_init=True, show_progress=True, random_state=17)
|
||||
@@ -34,6 +34,14 @@ def test_TAN_random_head(data):
|
||||
assert clf.head_ == 3
|
||||
|
||||
|
||||
def test_TAN_dag_initializer(data):
|
||||
clf_not_simple = TAN(simple_init=False)
|
||||
clf_simple = TAN(simple_init=True)
|
||||
clf_not_simple.fit(*data, head=0)
|
||||
clf_simple.fit(*data, head=0)
|
||||
assert clf_simple.dag_.edges == clf_not_simple.dag_.edges
|
||||
|
||||
|
||||
def test_TAN_classifier(data):
|
||||
clf = TAN()
|
||||
|
||||
|
Reference in New Issue
Block a user