mirror of
https://github.com/Doctorado-ML/bayesclass.git
synced 2025-08-20 18:15:57 +00:00
Update head hyperparam to use highest weight
This commit is contained in:
@@ -1,12 +1,18 @@
|
||||
"""
|
||||
This is a module to be used as a reference for building other modules
|
||||
"""
|
||||
import random
|
||||
from itertools import combinations
|
||||
import pandas as pd
|
||||
from sklearn.base import ClassifierMixin, BaseEstimator
|
||||
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
|
||||
from sklearn.utils.multiclass import unique_labels
|
||||
import networkx as nx
|
||||
from pgmpy.estimators import TreeSearch, BayesianEstimator
|
||||
from pgmpy.estimators import (
|
||||
TreeSearch,
|
||||
BayesianEstimator,
|
||||
MaximumLikelihoodEstimator,
|
||||
)
|
||||
from pgmpy.models import BayesianNetwork
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
@@ -29,9 +35,12 @@ class TAN(ClassifierMixin, BaseEstimator):
|
||||
The classes seen at :meth:`fit`.
|
||||
"""
|
||||
|
||||
def __init__(
    self, simple_init=False, show_progress=False, random_state=None
):
    """Store the hyperparameters; no validation or computation happens here.

    Parameters
    ----------
    simple_init : bool, default=False
        When true, the initial feature edges all start from the head
        feature instead of covering every pair of features.
    show_progress : bool, default=False
        Passed on to the structure search when the model is trained.
    random_state : int or None, default=None
        When not None, used in ``fit`` to seed the ``random`` module
        before a ``"random"`` head is drawn.
    """
    self.simple_init = simple_init
    self.show_progress = show_progress
    self.random_state = random_state
|
||||
|
||||
def fit(self, X, y, **kwargs):
|
||||
"""A reference implementation of a fitting function for a classifier.
|
||||
@@ -44,7 +53,8 @@ class TAN(ClassifierMixin, BaseEstimator):
|
||||
**kwargs : dict
|
||||
class_name : str (default='class') Name of the class column
|
||||
features: list (default=None) List of features
|
||||
head: int (default=0) Index of the head node
|
||||
head: int (default=None) Index of the head node. Default value
|
||||
gets the node with the highest sum of weights (mutual_info)
|
||||
Returns
|
||||
-------
|
||||
self : object
|
||||
@@ -57,20 +67,22 @@ class TAN(ClassifierMixin, BaseEstimator):
|
||||
# Default values
|
||||
self.class_name_ = "class"
|
||||
self.features_ = [f"feature_{i}" for i in range(X.shape[1])]
|
||||
self.head_ = 0
|
||||
self.head_ = None
|
||||
expected_args = ["class_name", "features", "head"]
|
||||
for key, value in kwargs.items():
|
||||
if key in expected_args:
|
||||
setattr(self, f"{key}_", value)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unexpected argument: {key}")
|
||||
|
||||
if self.random_state is not None:
|
||||
random.seed(self.random_state)
|
||||
if self.head_ == "random":
|
||||
self.head_ = random.randint(0, len(self.features_) - 1)
|
||||
if len(self.features_) != X.shape[1]:
|
||||
raise ValueError(
|
||||
"Number of features does not match the number of columns in X"
|
||||
)
|
||||
if self.head_ >= len(self.features_):
|
||||
if self.head_ is not None and self.head_ >= len(self.features_):
|
||||
raise ValueError("Head index out of range")
|
||||
|
||||
self.X_ = X
|
||||
@@ -80,37 +92,57 @@ class TAN(ClassifierMixin, BaseEstimator):
|
||||
return self
|
||||
|
||||
def __initial_edges(self):
    """Build the feature-to-feature edges of the initial network.

    As with the naive Bayes, in a TAN structure, the class has no
    parents, while features must have the class as parent and are forced
    to have one other feature as parent too (except for one single
    feature, which has only the class as parent and is considered the
    root of the features' tree).

    Cassio P. de Campos, Giorgio Corani, Mauro Scanagatta, Marco Cuccu,
    Marco Zaffalon,
    Learning extended tree augmented naive structures,
    International Journal of Approximate Reasoning.

    Returns
    -------
    list
        List of ``(source, target)`` tuples of feature names.
    """
    # No explicit head means the first feature is used as the start node.
    head = 0 if self.head_ is None else self.head_
    if self.simple_init:
        # Star layout: the head feature points at every other feature.
        first_node = self.features_[head]
        return [
            (first_node, feature)
            for feature in self.features_
            if feature != first_node
        ]
    # initialize a complete network with all edges starting from head:
    # rotate the feature list so the head comes first, then pair every
    # feature with each feature that follows it in that order.
    n_features = len(self.features_)
    reordered = [
        self.features_[idx % n_features]
        for idx in range(head, n_features + head)
    ]
    return list(combinations(reordered, 2))
|
||||
|
||||
def __train(self):
    """Learn the network structure from the stored data and fit it.

    Reads ``X_``, ``y_``, ``features_``, ``class_name_`` and ``head_``;
    writes ``dataset_``, ``model_`` and, when no head was given,
    ``head_``.
    """
    # Initialize a Naive Bayes model
    net = [(self.class_name_, feature) for feature in self.features_]
    self.model_ = BayesianNetwork(net)
    # initialize a complete network with all edges
    self.model_.add_edges_from(self.__initial_edges())
    self.dataset_ = pd.DataFrame(self.X_, columns=self.features_)
    self.dataset_[self.class_name_] = self.y_
    # learn graph structure; with no head the search receives
    # root_node=None and is left to choose the root itself
    root_node = None if self.head_ is None else self.features_[self.head_]
    est = TreeSearch(self.dataset_, root_node=root_node)
    dag = est.estimate(
        estimator_type="tan",
        class_node=self.class_name_,
        show_progress=self.show_progress,
    )
    if self.head_ is None:
        # NOTE(review): est.root_node is presumably a column name, while
        # a user-supplied head_ is an index into features_ -- confirm
        # which form callers of head_ expect after fitting.
        self.head_ = est.root_node
    # The model built above is replaced by the learned structure.
    self.model_ = BayesianNetwork(dag.edges())
    self.model_.fit(
        self.dataset_,
        # estimator=MaximumLikelihoodEstimator,
        estimator=BayesianEstimator,
        prior_type="K2",
    )
|
||||
|
@@ -16,12 +16,26 @@ def data():
|
||||
return enc.fit_transform(X), y
|
||||
|
||||
|
||||
def test_TAN_classifier(data):
|
||||
def test_TAN_constructor():
    """The constructor must keep each hyperparameter as it was passed."""
    # Test default values of hyperparameters
    default_clf = TAN()
    assert not default_clf.simple_init
    assert not default_clf.show_progress
    assert default_clf.random_state is None
    # Explicit values must be stored as given
    custom_clf = TAN(simple_init=True, show_progress=True, random_state=17)
    assert custom_clf.simple_init
    assert custom_clf.show_progress
    assert custom_clf.random_state == 17
|
||||
|
||||
|
||||
def test_TAN_random_head(data):
    """A seeded classifier fitted with head="random" picks a fixed head."""
    X, y = data
    model = TAN(random_state=17)
    model.fit(X, y, head="random")
    # Expected draw for seed 17 on the fixture's feature count
    assert model.head_ == 3
|
||||
|
||||
|
||||
def test_TAN_classifier(data):
|
||||
clf = TAN()
|
||||
|
||||
clf.fit(*data)
|
||||
attribs = ["classes_", "X_", "y_", "head_", "features_", "class_name_"]
|
||||
|
Reference in New Issue
Block a user