feat: 🧐 Add nodes, edges and states info to models

This commit is contained in:
Ricardo Montañana Gómez
2023-01-22 14:01:54 +01:00
parent 8fd796155d
commit bdd3f483d9
4 changed files with 89 additions and 18 deletions

View File

@@ -1,4 +1,5 @@
import random
import warnings
import numpy as np
import pandas as pd
from scipy.stats import mode
@@ -18,6 +19,9 @@ class BayesBase(BaseEstimator, ClassifierMixin):
def __init__(self, random_state, show_progress):
    # Store constructor args verbatim (sklearn convention: no work in __init__).
    self.random_state = random_state
    self.show_progress = show_progress
    # To keep compatibility with the benchmark platform.
    # nodes_leaves aliases the bound method nodes_edges, so calling
    # clf.nodes_leaves() always reflects the current fitted model.
    self.nodes_leaves = self.nodes_edges
    # NOTE(review): states_ is a property, so this line evaluates it NOW
    # (0 before fit) and freezes that value; depth_ will never update after
    # fitting. Likely a bug — consider a `depth_` property returning
    # self.states_ instead.
    self.depth_ = self.states_
def _more_tags(self):
return {
@@ -32,10 +36,10 @@ class BayesBase(BaseEstimator, ClassifierMixin):
"""Return the version of the package."""
return __version__
def nodes_edges(self):
    """Return ``(#nodes, #edges)`` of the learned DAG.

    Returns ``(0, 0)`` when the classifier has not been fitted yet
    (i.e. ``dag_`` does not exist).  ``nodes_leaves`` (set in
    ``__init__``) aliases this method for benchmark-platform
    compatibility.
    """
    if hasattr(self, "dag_"):
        return len(self.dag_), len(self.dag_.edges())
    return 0, 0
def _check_params_fit(self, X, y, expected_args, kwargs):
"""Check the common parameters passed to fit"""
@@ -61,6 +65,12 @@ class BayesBase(BaseEstimator, ClassifierMixin):
self.n_features_in_ = X.shape[1]
return X, y
@property
def states_(self):
    """Total number of states over all variables of the fitted model.

    Returns 0 when the classifier has not been fitted yet.
    """
    if not hasattr(self, "fitted_"):
        return 0
    return sum(len(values) for values in self.model_.states.values())
def fit(self, X, y, **kwargs):
"""A reference implementation of a fitting function for a classifier.
@@ -180,6 +190,7 @@ class BayesBase(BaseEstimator, ClassifierMixin):
return self.model_.predict(dataset).values.ravel()
def plot(self, title="", node_size=800):
warnings.simplefilter("ignore", UserWarning)
nx.draw_circular(
self.model_,
with_labels=True,
@@ -240,12 +251,37 @@ class TAN(BayesBase):
return X, y
def _build(self):
    """Build the TAN structure and store it in ``self.dag_``.

    Re-implements the relevant part of pgmpy's
    ``TreeSearch.estimate(estimator_type="tan")`` (code adapted from
    pgmpy) so the tree root can be ``self.features_[self.head_]``
    rather than the estimator's default.

    NOTE(review): this relies on pgmpy *private* helpers
    (``_get_conditional_weights`` / ``_create_tree_and_dag``); pin the
    pgmpy version or these may break on upgrade.
    """
    n_jobs = -1  # use all cores for the pairwise weight computation
    # Conditional mutual-information weights between features given the class.
    weights = TreeSearch._get_conditional_weights(
        self.dataset_,
        self.class_name_,
        "mutual_info",
        n_jobs,
        self.show_progress,
    )
    # Construct the Chow-Liu DAG on {data.columns - class_node}: drop the
    # class row/column from the weight matrix and from the column list.
    class_node_idx = np.where(self.dataset_.columns == self.class_name_)[0][0]
    weights = np.delete(weights, class_node_idx, axis=0)
    weights = np.delete(weights, class_node_idx, axis=1)
    reduced_columns = np.delete(self.dataset_.columns, class_node_idx)
    D = TreeSearch._create_tree_and_dag(
        weights, reduced_columns, self.features_[self.head_]
    )
    # Add edges from class_node to all other nodes (TAN: the class is a
    # parent of every feature).
    D.add_edges_from(
        [(self.class_name_, node) for node in reduced_columns]
    )
    self.dag_ = D
class KDB(BayesBase):
@@ -331,12 +367,25 @@ class AODE(BayesBase, BaseEnsemble):
expected_args = ["class_name", "features", "state_names"]
return self._check_params_fit(X, y, expected_args, kwargs)
def nodes_edges(self):
    """Return ``(#nodes, #edges)`` summed over all fitted sub-models.

    Returns ``(0, 0)`` when the ensemble has not been fitted yet.
    ``nodes_leaves`` (set in ``BayesBase.__init__``) aliases this method
    for benchmark-platform compatibility.
    """
    nodes = 0
    edges = 0
    if hasattr(self, "fitted_"):
        nodes = sum(len(model) for model in self.models_)
        edges = sum(len(model.edges()) for model in self.models_)
    return nodes, edges
@property
def states_(self):
    """Average number of variable states per sub-model.

    Returns 0 when the ensemble has not been fitted yet; otherwise the
    total state count over every sub-model divided by the number of
    sub-models (a float).
    """
    if not hasattr(self, "fitted_"):
        return 0
    total = sum(
        len(states)
        for model in self.models_
        for states in model.states.values()
    )
    return total / len(self.models_)
def _build(self):
    # The ensemble keeps its per-model structures in self.models_
    # (see nodes_edges/states_); there is no single DAG for AODE.
    self.dag_ = None
@@ -365,6 +414,7 @@ class AODE(BayesBase, BaseEnsemble):
self.models_.append(model)
def plot(self, title=""):
    """Plot every sub-model in turn, prefixing the title with its index."""
    # Silence UserWarnings raised while drawing (same as BayesBase.plot).
    warnings.simplefilter("ignore", UserWarning)
    for idx, model in enumerate(self.models_):
        # BayesBase.plot draws self.model_, so point it at each sub-model.
        self.model_ = model
        super().plot(title=f"{idx} {title}")

View File

@@ -55,10 +55,17 @@ def test_AODE_version(clf):
assert __version__ == clf.version()
def test_AODE_nodes_edges(clf, data):
    # Unfitted ensemble reports (0, 0); nodes_leaves aliases nodes_edges.
    assert clf.nodes_leaves() == (0, 0)
    clf.fit(*data)
    # Totals summed over all sub-models of the ensemble.
    assert clf.nodes_leaves() == (20, 28)
def test_AODE_states(clf, data):
    # Unfitted ensemble reports 0 states.
    assert clf.states_ == 0
    # Rebuild with a fixed seed so the fit is reproducible.
    clf = AODE(random_state=17)
    clf.fit(*data)
    assert clf.states_ == 23
def test_AODE_classifier(data, clf):

View File

@@ -47,10 +47,17 @@ def test_KDB_version(clf):
assert __version__ == clf.version()
def test_KDB_nodes_edges(clf, data):
    # Unfitted classifier reports (0, 0); nodes_leaves aliases nodes_edges.
    assert clf.nodes_leaves() == (0, 0)
    clf.fit(*data)
    assert clf.nodes_leaves() == (5, 10)
def test_KDB_states(clf, data):
    # Unfitted classifier reports 0 states.
    assert clf.states_ == 0
    # Rebuild with a fixed seed so the fit is reproducible.
    clf = KDB(k=3, random_state=17)
    clf.fit(*data)
    assert clf.states_ == 23
def test_KDB_classifier(data, clf):

View File

@@ -45,11 +45,18 @@ def test_TAN_version(clf):
assert __version__ == clf.version()
def test_TAN_nodes_edges(clf, data):
    # Unfitted classifier reports (0, 0); nodes_leaves aliases nodes_edges.
    assert clf.nodes_leaves() == (0, 0)
    clf = TAN(random_state=17)
    clf.fit(*data, head="random")
    assert clf.nodes_leaves() == (5, 7)
def test_TAN_states(clf, data):
    # Unfitted classifier reports 0 states.
    assert clf.states_ == 0
    # Rebuild with a fixed seed so the fit is reproducible.
    clf = TAN(random_state=17)
    clf.fit(*data)
    assert clf.states_ == 23
def test_TAN_random_head(data):