mirror of https://github.com/Doctorado-ML/bayesclass.git (synced 2025-08-17 16:45:54 +00:00)
feat: 🧐 Add nodes, edges and states info to models
@@ -1,4 +1,5 @@
import random
import warnings
import numpy as np
import pandas as pd
from scipy.stats import mode
@@ -18,6 +19,9 @@ class BayesBase(BaseEstimator, ClassifierMixin):
    def __init__(self, random_state, show_progress):
        self.random_state = random_state
        self.show_progress = show_progress
        # To keep compatibility with the benchmark platform
        self.nodes_leaves = self.nodes_edges
        self.depth_ = self.states_

    def _more_tags(self):
        return {
@@ -32,10 +36,10 @@ class BayesBase(BaseEstimator, ClassifierMixin):
        """Return the version of the package."""
        return __version__

    def nodes_leaves(self):
        """To keep compatibility with the benchmark platform"""
        nodes = len(self.dag_) if hasattr(self, "dag_") else 0
        return nodes, 0
    def nodes_edges(self):
        if hasattr(self, "dag_"):
            return len(self.dag_), len(self.dag_.edges())
        return 0, 0

    def _check_params_fit(self, X, y, expected_args, kwargs):
        """Check the common parameters passed to fit"""
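The new nodes_edges() relies on the DAG behaving like a networkx DiGraph, where len(graph) is the node count and graph.edges() is the edge view. A minimal standalone sketch of that counting logic (toy graph, illustrative names only):

import networkx as nx

dag = nx.DiGraph()
dag.add_edges_from([("class", "f1"), ("class", "f2"), ("f1", "f2")])

# len() of a DiGraph counts nodes; edges() yields the directed arcs
nodes, edges = len(dag), len(dag.edges())
print(nodes, edges)  # 3 3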
@@ -61,6 +65,12 @@ class BayesBase(BaseEstimator, ClassifierMixin):
        self.n_features_in_ = X.shape[1]
        return X, y

    @property
    def states_(self):
        if hasattr(self, "fitted_"):
            return sum([len(item) for _, item in self.model_.states.items()])
        return 0

    def fit(self, X, y, **kwargs):
        """A reference implementation of a fitting function for a classifier.
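The model's states attribute maps each variable to its list of discrete states, so the property above is the total number of states over all variables. A small sketch with a stand-in dict (illustrative values, not from the diff):

# Stand-in for model_.states: variable -> list of discrete states
states = {
    "f1": [0, 1, 2],
    "f2": [0, 1],
    "class": ["a", "b", "c"],
}
total_states = sum(len(item) for _, item in states.items())
print(total_states)  # 8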
@@ -180,6 +190,7 @@ class BayesBase(BaseEstimator, ClassifierMixin):
        return self.model_.predict(dataset).values.ravel()

    def plot(self, title="", node_size=800):
        warnings.simplefilter("ignore", UserWarning)
        nx.draw_circular(
            self.model_,
            with_labels=True,
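Plotting goes through networkx's circular layout; a self-contained sketch of the same drawing call on a toy DAG (matplotlib assumed available, names illustrative):

import matplotlib.pyplot as plt
import networkx as nx

toy = nx.DiGraph([("class", "f1"), ("class", "f2"), ("f1", "f2")])
# Same primitive as plot(): circular layout with labelled nodes
nx.draw_circular(toy, with_labels=True, node_size=800)
plt.title("toy model")
plt.show()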
@@ -240,12 +251,37 @@ class TAN(BayesBase):
        return X, y

    def _build(self):
        est = TreeSearch(self.dataset_, root_node=self.features_[self.head_])
        self.dag_ = est.estimate(
            estimator_type="tan",
            class_node=self.class_name_,
            show_progress=self.show_progress,
        # est = TreeSearch(self.dataset_, root_node=self.features_[self.head_])
        # self.dag_ = est.estimate(
        #     estimator_type="tan",
        #     class_node=self.class_name_,
        #     show_progress=self.show_progress,
        # )
        # Code taken from pgmpy
        n_jobs = -1
        weights = TreeSearch._get_conditional_weights(
            self.dataset_,
            self.class_name_,
            "mutual_info",
            n_jobs,
            self.show_progress,
        )
        # Step 4.2: Construct chow-liu DAG on {data.columns - class_node}
        class_node_idx = np.where(self.dataset_.columns == self.class_name_)[
            0
        ][0]
        weights = np.delete(weights, class_node_idx, axis=0)
        weights = np.delete(weights, class_node_idx, axis=1)
        reduced_columns = np.delete(self.dataset_.columns, class_node_idx)
        D = TreeSearch._create_tree_and_dag(
            weights, reduced_columns, self.features_[self.head_]
        )

        # Step 4.3: Add edges from class_node to all other nodes.
        D.add_edges_from(
            [(self.class_name_, node) for node in reduced_columns]
        )
        self.dag_ = D


class KDB(BayesBase):
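The inlined pgmpy steps amount to: trim the class row/column out of the pairwise-weight matrix, build a maximum spanning tree over the remaining features, orient it away from the chosen head, and then add an arc from the class to every feature. A rough standalone sketch of that idea using plain networkx (random weights and made-up names, not the pgmpy internals):

import numpy as np
import networkx as nx

columns = np.array(["f0", "f1", "f2", "f3", "class"])
class_name, head = "class", "f0"

rng = np.random.default_rng(17)
weights = rng.random((5, 5))
weights = (weights + weights.T) / 2  # symmetric pairwise weights

# Drop the class row/column, keep only the feature columns
class_idx = int(np.where(columns == class_name)[0][0])
weights = np.delete(weights, class_idx, axis=0)
weights = np.delete(weights, class_idx, axis=1)
reduced_columns = np.delete(columns, class_idx)

# Maximum spanning tree over the features, oriented away from the head
G = nx.Graph()
for i, u in enumerate(reduced_columns):
    for j in range(i + 1, len(reduced_columns)):
        G.add_edge(u, reduced_columns[j], weight=weights[i, j])
dag = nx.bfs_tree(nx.maximum_spanning_tree(G), head)

# Add edges from the class node to all other nodes
dag.add_edges_from((class_name, node) for node in reduced_columns)
print(len(dag), len(dag.edges()))  # 5 nodes, 7 edges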
@@ -331,12 +367,25 @@ class AODE(BayesBase, BaseEnsemble):
        expected_args = ["class_name", "features", "state_names"]
        return self._check_params_fit(X, y, expected_args, kwargs)

    def nodes_leaves(self):
        """To keep compatibility with the benchmark platform"""
    def nodes_edges(self):
        nodes = 0
        edges = 0
        if hasattr(self, "fitted_"):
            nodes = sum([len(x) for x in self.models_])
        return nodes, 0
            edges = sum([len(x.edges()) for x in self.models_])
        return nodes, edges

    @property
    def states_(self):
        if hasattr(self, "fitted_"):
            return sum(
                [
                    len(item)
                    for model in self.models_
                    for _, item in model.states.items()
                ]
            ) / len(self.models_)
        return 0

    def _build(self):
        self.dag_ = None
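For the ensemble, states_ above is the average total-state count across the SPODE models; since every SPODE is built over the same variables, this typically equals a single model's total. A tiny sketch with stand-in dicts (illustrative values, not from the diff):

# Stand-in for [model.states for model in self.models_]
models_states = [
    {"f1": [0, 1], "f2": [0, 1, 2], "class": ["a", "b"]},
    {"f1": [0, 1], "f2": [0, 1, 2], "class": ["a", "b"]},
]
avg_states = sum(
    len(item) for states in models_states for _, item in states.items()
) / len(models_states)
print(avg_states)  # 7.0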
@@ -365,6 +414,7 @@ class AODE(BayesBase, BaseEnsemble):
            self.models_.append(model)

    def plot(self, title=""):
        warnings.simplefilter("ignore", UserWarning)
        for idx, model in enumerate(self.models_):
            self.model_ = model
            super().plot(title=f"{idx} {title}")
@@ -55,10 +55,17 @@ def test_AODE_version(clf):
    assert __version__ == clf.version()


def test_AODE_nodes_leaves(clf, data):
def test_AODE_nodes_edges(clf, data):
    assert clf.nodes_leaves() == (0, 0)
    clf.fit(*data)
    assert clf.nodes_leaves() == (20, 0)
    assert clf.nodes_leaves() == (20, 28)


def test_AODE_states(clf, data):
    assert clf.states_ == 0
    clf = AODE(random_state=17)
    clf.fit(*data)
    assert clf.states_ == 23


def test_AODE_classifier(data, clf):
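The updated expectation (20, 28) is consistent with a four-feature dataset: each of the four SPODE models would contribute 5 nodes (4 features plus the class) and 7 edges (4 class-to-feature arcs plus 3 super-parent arcs), i.e. 4 x 5 = 20 and 4 x 7 = 28. The 23 in test_AODE_states would then be the per-model sum of state counts over those five variables; both readings are inferred from the asserted numbers, not stated in the diff.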
@@ -47,10 +47,17 @@ def test_KDB_version(clf):
    assert __version__ == clf.version()


def test_KDB_nodes_leaves(clf, data):
def test_KDB_nodes_edges(clf, data):
    assert clf.nodes_leaves() == (0, 0)
    clf.fit(*data)
    assert clf.nodes_leaves() == (5, 0)
    assert clf.nodes_leaves() == (5, 10)


def test_KDB_states(clf, data):
    assert clf.states_ == 0
    clf = KDB(k=3, random_state=17)
    clf.fit(*data)
    assert clf.states_ == 23


def test_KDB_classifier(data, clf):
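Here (5, 10) would fit the same four-feature reading: 5 nodes, 4 class-to-feature arcs, and 0 + 1 + 2 + 3 = 6 feature-to-feature arcs (each new feature taking up to k = 3 parents), for 10 edges in total; this is an inference from the assertions, not something the diff states.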
@@ -45,11 +45,18 @@ def test_TAN_version(clf):
    assert __version__ == clf.version()


def test_TAN_nodes_leaves(clf, data):
def test_TAN_nodes_edges(clf, data):
    assert clf.nodes_leaves() == (0, 0)
    clf = TAN(random_state=17)
    clf.fit(*data, head="random")
    assert clf.nodes_leaves() == (5, 0)
    assert clf.nodes_leaves() == (5, 7)


def test_TAN_states(clf, data):
    assert clf.states_ == 0
    clf = TAN(random_state=17)
    clf.fit(*data)
    assert clf.states_ == 23


def test_TAN_random_head(data):
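For TAN the (5, 7) assertion matches the construction in _build above: with four features there are 4 class-to-feature arcs plus the 3 arcs of the feature spanning tree, giving 7 edges over 5 nodes (again inferred from the test values rather than stated in the diff).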