Begin with kDB classifier

2022-11-14 14:03:45 +01:00
parent 6b2e60eba0
commit 1545ca62cf
4 changed files with 168 additions and 133 deletions

View File

@@ -3,6 +3,7 @@ This is a module to be used as a reference for building other modules
"""
import random
from itertools import combinations
import numpy as np
import pandas as pd
from sklearn.base import ClassifierMixin, BaseEstimator
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
@@ -14,36 +15,8 @@ import matplotlib.pyplot as plt
from ._version import __version__
class TAN(ClassifierMixin, BaseEstimator):
"""An example classifier which implements a 1-NN algorithm.
For more information regarding how to build your own classifier, read more
in the :ref:`User Guide <user_guide>`.
Parameters
----------
demo_param : str, default='demo'
A parameter used for demonstration of how to pass and store parameters.
Attributes
----------
X_ : ndarray, shape (n_samples, n_features)
The input passed during :meth:`fit`.
y_ : ndarray, shape (n_samples,)
The labels passed during :meth:`fit`.
classes_ : ndarray, shape (n_classes,)
The classes seen at :meth:`fit`.
"""
def __init__(
self, simple_init=True, show_progress=False, random_state=None
):
self.simple_init = simple_init
self.show_progress = show_progress
self.random_state = random_state
class BayesBase(BaseEstimator, ClassifierMixin):
def _more_tags(self):
import numpy as np
return {
"requires_positive_X": True,
"requires_positive_y": True,
@@ -56,32 +29,9 @@ class TAN(ClassifierMixin, BaseEstimator):
"""Return the version of the package."""
return __version__
def __check_params_fit(self, X, y, kwargs):
# Check that X and y have correct shape
X, y = check_X_y(X, y)
# Store the classes seen during fit
self.classes_ = unique_labels(y)
# Default values
self.class_name_ = "class"
self.features_ = [f"feature_{i}" for i in range(X.shape[1])]
self.head_ = 0
expected_args = ["class_name", "features", "head"]
for key, value in kwargs.items():
if key in expected_args:
setattr(self, f"{key}_", value)
else:
raise ValueError(f"Unexpected argument: {key}")
if self.random_state is not None:
random.seed(self.random_state)
if self.head_ == "random":
self.head_ = random.randint(0, len(self.features_) - 1)
if len(self.features_) != X.shape[1]:
raise ValueError(
"Number of features does not match the number of columns in X"
)
if self.head_ is not None and self.head_ >= len(self.features_):
raise ValueError("Head index out of range")
return X, y
def nodes_leaves(self):
"""To keep compatiblity with the benchmark platform"""
return 0, 0
def fit(self, X, y, **kwargs):
"""A reference implementation of a fitting function for a classifier.
@@ -121,89 +71,47 @@ class TAN(ClassifierMixin, BaseEstimator):
>>> model.fit(train_data, train_y, features=features, class_name='E')
TAN(random_state=17)
"""
X_, y_ = self.__check_params_fit(X, y, kwargs)
X_, y_ = self._check_params_fit(X, y, kwargs)
# Store the information needed to build the model
self.X_ = X_
self.y_ = y_
self.dataset_ = pd.DataFrame(self.X_, columns=self.features_)
self.dataset_[self.class_name_] = self.y_
# Build the DAG
self.__build()
self._build()
# Train the model
self.__train()
self._train()
self.fitted_ = True
# Return the classifier
return self
def __initial_edges(self):
"""As with the naive Bayes, in a TAN structure, the class has no
parents, while features must have the class as parent and are forced to
have one other feature as parent too (except for one single feature,
which has only the class as parent and is considered the root of the
features' tree)
Cassio P. de Campos, Giorgio Corani, Mauro Scanagatta, Marco Cuccu,
Marco Zaffalon,
Learning extended tree augmented naive structures,
International Journal of Approximate Reasoning,
Returns
-------
List
List of edges
"""
head = self.head_
if self.simple_init:
first_node = self.features_[head]
return [
(first_node, feature)
for feature in self.features_
if feature != first_node
]
# initialize a complete network with all edges starting from head
reordered = [
self.features_[idx % len(self.features_)]
for idx in range(head, len(self.features_) + head)
]
return list(combinations(reordered, 2))
def __build(self):
# Initialize a Naive Bayes model
net = [(self.class_name_, feature) for feature in self.features_]
self.model_ = BayesianNetwork(net)
# initialize a complete network with all edges
self.model_.add_edges_from(self.__initial_edges())
# learn graph structure
est = TreeSearch(self.dataset_, root_node=self.features_[self.head_])
self.dag_ = est.estimate(
estimator_type="tan",
class_node=self.class_name_,
show_progress=self.show_progress,
)
def __train(self):
self.model_ = BayesianNetwork(
self.dag_.edges(), show_progress=self.show_progress
)
self.model_.fit(
self.dataset_,
estimator=BayesianEstimator,
prior_type="K2",
)
def nodes_leaves(self):
return 0, 0
def plot(self, title=""):
nx.draw_circular(
self.model_,
with_labels=True,
arrowsize=30,
node_size=800,
alpha=0.3,
font_weight="bold",
)
plt.title(title)
plt.show()
def _check_params_fit(self, X, y, kwargs):
"""Check the parameters passed to fit"""
# Check that X and y have correct shape
X, y = check_X_y(X, y)
# Store the classes seen during fit
self.classes_ = unique_labels(y)
# Default values
self.class_name_ = "class"
self.features_ = [f"feature_{i}" for i in range(X.shape[1])]
self.head_ = 0
expected_args = ["class_name", "features", "head"]
for key, value in kwargs.items():
if key in expected_args:
setattr(self, f"{key}_", value)
else:
raise ValueError(f"Unexpected argument: {key}")
if self.random_state is not None:
random.seed(self.random_state)
if self.head_ == "random":
self.head_ = random.randint(0, len(self.features_) - 1)
if len(self.features_) != X.shape[1]:
raise ValueError(
"Number of features does not match the number of columns in X"
)
if self.head_ is not None and self.head_ >= len(self.features_):
raise ValueError("Head index out of range")
return X, y
def predict(self, X):
"""A reference implementation of a prediction for a classifier.
@@ -257,4 +165,129 @@ class TAN(ClassifierMixin, BaseEstimator):
# Input validation
X = check_array(X)
dataset = pd.DataFrame(X, columns=self.features_, dtype="int16")
return self.model_.predict(dataset, n_jobs=1).to_numpy()
return self.model_.predict(dataset).values.ravel()
class TAN(BayesBase):
"""Tree Augmented Naive Bayes
Parameters
----------
simple_init : bool, default=True
Controls how the initial DAG is built. If True, only the head feature is
used as parent of all the other features.
show_progress : bool, default=False
Whether pgmpy displays a progress bar while learning the structure.
random_state : int, default=None
Random state for reproducibility.
Attributes
----------
X_ : ndarray, shape (n_samples, n_features)
The input passed during :meth:`fit`.
y_ : ndarray, shape (n_samples,)
The labels passed during :meth:`fit`.
classes_ : ndarray, shape (n_classes,)
The classes seen at :meth:`fit`.
class_name_ : str
The name of the class column
features_ : list
The list of feature names
head_ : int
The index of the node used as head for the initial DAG
dataset_ : pd.DataFrame
The dataset used to train the model (X_ + y_)
dag_ : nx.DiGraph
The TAN DAG
model_ : BayesianNetwork
The actual classifier
"""
def __init__(
self, simple_init=True, show_progress=False, random_state=None
):
self.simple_init = simple_init
self.show_progress = show_progress
self.random_state = random_state
def __initial_edges(self):
"""As with the naive Bayes, in a TAN structure, the class has no
parents, while features must have the class as parent and are forced to
have one other feature as parent too (except for one single feature,
which has only the class as parent and is considered the root of the
features' tree)
Cassio P. de Campos, Giorgio Corani, Mauro Scanagatta, Marco Cuccu,
Marco Zaffalon,
Learning extended tree augmented naive structures,
International Journal of Approximate Reasoning,
Returns
-------
List
List of edges
"""
head = self.head_
if self.simple_init:
first_node = self.features_[head]
return [
(first_node, feature)
for feature in self.features_
if feature != first_node
]
# initialize a complete network with all edges starting from head
reordered = [
self.features_[idx % len(self.features_)]
for idx in range(head, len(self.features_) + head)
]
return list(combinations(reordered, 2))
def _build(self):
# Initialize a Naive Bayes model
net = [(self.class_name_, feature) for feature in self.features_]
self.model_ = BayesianNetwork(net)
# add the initial feature-to-feature edges (a complete graph over the features only when simple_init is False)
self.model_.add_edges_from(self.__initial_edges())
# learn graph structure
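# (pgmpy's TreeSearch with estimator_type="tan" learns a tree over the
# features rooted at root_node and adds class_node as a parent of every feature)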
est = TreeSearch(self.dataset_, root_node=self.features_[self.head_])
self.dag_ = est.estimate(
estimator_type="tan",
class_node=self.class_name_,
show_progress=self.show_progress,
)
def _train(self):
self.model_ = BayesianNetwork(
self.dag_.edges(), show_progress=self.show_progress
)
self.model_.fit(
self.dataset_,
estimator=BayesianEstimator,
prior_type="K2",
)
def plot(self, title=""):
nx.draw_circular(
self.model_,
with_labels=True,
arrowsize=30,
node_size=800,
alpha=0.3,
font_weight="bold",
)
plt.title(title)
plt.show()
class KDBayesClassifier(BayesBase):
def __init__(self, k=3, random_state=None):
self.k = k
self.random_state = random_state
@staticmethod
def version() -> str:
"""Return the version of the package."""
return __version__
def _build(self):
pass
def _train(self):
pass
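
With the shared plumbing now in BayesBase and the KDBayesClassifier skeleton added (its _build and _train are still empty), TAN is used exactly as before. A minimal usage sketch follows; the toy data, feature names and class name are illustrative assumptions, not part of the repository:

import numpy as np
from bayesclass import TAN

# toy, already-discretized data (illustrative only)
rng = np.random.default_rng(17)
X = rng.integers(0, 3, size=(150, 4))
y = (X[:, 0] + X[:, 1] > 2).astype(int)

model = TAN(simple_init=True, random_state=17)
model.fit(X, y, features=["A", "B", "C", "D"], class_name="E")
y_pred = model.predict(X)   # 1-D array of shape (n_samples,) after this change
print((y_pred == y).sum(), "of", len(y), "samples predicted correctly")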

View File

@@ -69,8 +69,7 @@ def test_TAN_classifier(data):
X = data[0]
y = data[1]
y_pred = clf.predict(X)
y = y.reshape(-1, 1)
assert y_pred.shape == (X.shape[0], 1)
assert y_pred.shape == (X.shape[0],)
assert sum(y == y_pred) == 147
@@ -103,8 +102,7 @@ def test_TAN_classifier_simple_init(data):
X = data[0]
y = data[1]
y_pred = clf.predict(X)
y = y.reshape(-1, 1)
assert y_pred.shape == (X.shape[0], 1)
assert y_pred.shape == (X.shape[0],)
assert sum(y == y_pred) == 147
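
Both assertions change because predict() now flattens pgmpy's one-column prediction DataFrame with .values.ravel() instead of returning a 2-D array. A small sketch of the difference, using a stand-in DataFrame for what model_.predict() returns:

import pandas as pd

pred_df = pd.DataFrame({"class": [0, 1, 0]})  # stand-in for model_.predict(...) output
print(pred_df.to_numpy().shape)               # (3, 1) -> the old (X.shape[0], 1) assertion
print(pred_df.values.ravel().shape)           # (3,)   -> the new (X.shape[0],) assertion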

View File

@@ -7,4 +7,7 @@ from bayesclass import TAN
@pytest.mark.parametrize("estimator", [TAN()])
def test_all_estimators(estimator):
return check_estimator(estimator)
i = 0
for estimator, test in check_estimator(estimator, generate_only=True):
print(i := i + 1, test, "classes_")
# test(estimator)
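# generate_only=True makes check_estimator yield (estimator, check) pairs
# without executing them; re-enabling test(estimator) would run each check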

View File

@@ -51,6 +51,7 @@ source = ["bayesclass"]
[tool.black]
line-length = 79
target_version = ['py38', 'py39', 'py310']
include = '\.pyi?$'
exclude = '''
/(