feat: 🧐 Add nodes, edges and states info to models

@@ -1,4 +1,5 @@
 import random
+import warnings
 import numpy as np
 import pandas as pd
 from scipy.stats import mode
@@ -18,6 +19,9 @@ class BayesBase(BaseEstimator, ClassifierMixin):
     def __init__(self, random_state, show_progress):
         self.random_state = random_state
         self.show_progress = show_progress
+        # To keep compatiblity with the benchmark platform
+        self.nodes_leaves = self.nodes_edges
+        self.depth_ = self.states_

     def _more_tags(self):
         return {
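
The assignment `self.nodes_leaves = self.nodes_edges` in `__init__` is an alias, not a copy: it binds the renamed method to the instance under its old name, so benchmark code that still calls `clf.nodes_leaves()` transparently ends up in `nodes_edges()`. A minimal sketch of the pattern (hypothetical class, not part of the repository):

    class Renamed:
        def __init__(self):
            # Keep the old public name working by pointing it at the renamed method.
            self.nodes_leaves = self.nodes_edges

        def nodes_edges(self):
            return 5, 7


    clf = Renamed()
    print(clf.nodes_leaves())  # (5, 7), dispatched to nodes_edges()
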
@@ -32,10 +36,10 @@ class BayesBase(BaseEstimator, ClassifierMixin):
         """Return the version of the package."""
         return __version__

-    def nodes_leaves(self):
-        """To keep compatiblity with the benchmark platform"""
-        nodes = len(self.dag_) if hasattr(self, "dag_") else 0
-        return nodes, 0
+    def nodes_edges(self):
+        if hasattr(self, "dag_"):
+            return len(self.dag_), len(self.dag_.edges())
+        return 0, 0

     def _check_params_fit(self, X, y, expected_args, kwargs):
         """Check the common parameters passed to fit"""
@@ -61,6 +65,12 @@ class BayesBase(BaseEstimator, ClassifierMixin):
         self.n_features_in_ = X.shape[1]
         return X, y

+    @property
+    def states_(self):
+        if hasattr(self, "fitted_"):
+            return sum([len(item) for _, item in self.model_.states.items()])
+        return 0
+
     def fit(self, X, y, **kwargs):
         """A reference implementation of a fitting function for a classifier.

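
The new `nodes_edges()` and `states_` members only inspect the learned structure: the former counts nodes and directed edges of `self.dag_`, the latter sums the number of discrete states over all variables of the fitted `self.model_`. A small sketch of the same counting on stand-in objects (a `networkx.DiGraph` for the DAG and a plain dict for the `states` mapping; the values are made up, not the repository's test data):

    import networkx as nx

    # Stand-in for self.dag_: the class variable plus three features.
    dag = nx.DiGraph(
        [("class", "f1"), ("class", "f2"), ("class", "f3"), ("f1", "f2")]
    )
    # Stand-in for self.model_.states: variable -> list of admissible values.
    states = {"class": [0, 1], "f1": [0, 1, 2], "f2": [0, 1], "f3": [0, 1, 2, 3]}

    nodes, edges = len(dag), len(dag.edges())                   # (4, 4)
    total_states = sum(len(item) for item in states.values())   # 2 + 3 + 2 + 4 = 11
    print(nodes, edges, total_states)
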
@@ -180,6 +190,7 @@ class BayesBase(BaseEstimator, ClassifierMixin):
         return self.model_.predict(dataset).values.ravel()

     def plot(self, title="", node_size=800):
+        warnings.simplefilter("ignore", UserWarning)
         nx.draw_circular(
             self.model_,
             with_labels=True,
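
`plot()` now silences `UserWarning` before drawing, presumably to keep matplotlib/networkx warnings out of the output. The filter change is process-wide; a scoped alternative (a sketch of a design option, not what the commit does) would be:

    import warnings

    import networkx as nx


    def draw_quietly(graph, node_size=800):
        # catch_warnings() restores the previous filter state on exit,
        # unlike a bare simplefilter(), which alters it for the whole process.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            # Drawing requires matplotlib to be installed.
            nx.draw_circular(graph, with_labels=True, node_size=node_size)
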
@@ -240,12 +251,37 @@ class TAN(BayesBase):
         return X, y

     def _build(self):
-        est = TreeSearch(self.dataset_, root_node=self.features_[self.head_])
-        self.dag_ = est.estimate(
-            estimator_type="tan",
-            class_node=self.class_name_,
-            show_progress=self.show_progress,
+        # est = TreeSearch(self.dataset_, root_node=self.features_[self.head_])
+        # self.dag_ = est.estimate(
+        #     estimator_type="tan",
+        #     class_node=self.class_name_,
+        #     show_progress=self.show_progress,
+        # )
+        # Code taken from pgmpy
+        n_jobs = -1
+        weights = TreeSearch._get_conditional_weights(
+            self.dataset_,
+            self.class_name_,
+            "mutual_info",
+            n_jobs,
+            self.show_progress,
         )
+        # Step 4.2: Construct chow-liu DAG on {data.columns - class_node}
+        class_node_idx = np.where(self.dataset_.columns == self.class_name_)[
+            0
+        ][0]
+        weights = np.delete(weights, class_node_idx, axis=0)
+        weights = np.delete(weights, class_node_idx, axis=1)
+        reduced_columns = np.delete(self.dataset_.columns, class_node_idx)
+        D = TreeSearch._create_tree_and_dag(
+            weights, reduced_columns, self.features_[self.head_]
+        )
+
+        # Step 4.3: Add edges from class_node to all other nodes.
+        D.add_edges_from(
+            [(self.class_name_, node) for node in reduced_columns]
+        )
+        self.dag_ = D


 class KDB(BayesBase):
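
The rewritten `TAN._build()` inlines what pgmpy's `TreeSearch.estimate(estimator_type="tan")` does internally, so the root of the feature tree can be pinned to `self.features_[self.head_]`: compute class-conditional mutual-information weights, build a maximum-weight spanning tree over the features, orient it away from the chosen head, then add an edge from the class to every feature. A rough sketch of those last steps with networkx, using a made-up weight matrix in place of `TreeSearch._get_conditional_weights` (a pgmpy internal):

    import networkx as nx
    import numpy as np

    features = ["f0", "f1", "f2", "f3"]
    class_name = "class"
    head = "f0"

    # Hypothetical symmetric conditional mutual-information weights between features.
    weights = np.array(
        [
            [0.0, 0.9, 0.1, 0.2],
            [0.9, 0.0, 0.7, 0.3],
            [0.1, 0.7, 0.0, 0.6],
            [0.2, 0.3, 0.6, 0.0],
        ]
    )

    # Maximum-weight spanning tree over the features (the Chow-Liu skeleton).
    g = nx.Graph()
    for i in range(len(features)):
        for j in range(i + 1, len(features)):
            g.add_edge(features[i], features[j], weight=weights[i, j])
    tree = nx.maximum_spanning_tree(g)

    # Orient the tree away from the chosen head, then hang every feature off the class.
    dag = nx.bfs_tree(tree, head)
    dag.add_edges_from([(class_name, node) for node in features])

    # 3 tree edges + 4 class edges = 7, the edge count a 4-feature TAN yields.
    print(sorted(dag.edges()))
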
@@ -331,12 +367,25 @@ class AODE(BayesBase, BaseEnsemble):
         expected_args = ["class_name", "features", "state_names"]
         return self._check_params_fit(X, y, expected_args, kwargs)

-    def nodes_leaves(self):
-        """To keep compatiblity with the benchmark platform"""
+    def nodes_edges(self):
         nodes = 0
+        edges = 0
         if hasattr(self, "fitted_"):
             nodes = sum([len(x) for x in self.models_])
-        return nodes, 0
+            edges = sum([len(x.edges()) for x in self.models_])
+        return nodes, edges
+
+    @property
+    def states_(self):
+        if hasattr(self, "fitted_"):
+            return sum(
+                [
+                    len(item)
+                    for model in self.models_
+                    for _, item in model.states.items()
+                ]
+            ) / len(self.models_)
+        return 0

     def _build(self):
         self.dag_ = None
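
For the AODE ensemble the same statistics are aggregated over the per-super-parent SPODE models: node and edge counts are summed, while `states_` is divided by the number of models, presumably because every SPODE is built over the same variables and a plain sum would overcount. A small sketch with stand-in models (hypothetical structures, not the classifier's real networks):

    import networkx as nx


    class FakeSpode(nx.DiGraph):
        """Stand-in for one fitted pgmpy network in self.models_ (hypothetical)."""

        def __init__(self, edge_list, states):
            super().__init__(edge_list)
            self.states = states


    states = {"class": [0, 1], "f1": [0, 1], "f2": [0, 1, 2]}  # 7 states per model
    models = [
        FakeSpode([("class", "f1"), ("class", "f2"), ("f1", "f2")], states),
        FakeSpode([("class", "f1"), ("class", "f2"), ("f2", "f1")], states),
    ]

    nodes = sum(len(m) for m in models)              # 3 + 3 = 6
    edges = sum(len(m.edges()) for m in models)      # 3 + 3 = 6
    avg_states = sum(
        len(item) for m in models for item in m.states.values()
    ) / len(models)                                  # 14 / 2 = 7.0
    print(nodes, edges, avg_states)
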
@@ -365,6 +414,7 @@ class AODE(BayesBase, BaseEnsemble):
         self.models_.append(model)

     def plot(self, title=""):
+        warnings.simplefilter("ignore", UserWarning)
         for idx, model in enumerate(self.models_):
             self.model_ = model
             super().plot(title=f"{idx} {title}")
@@ -55,10 +55,17 @@ def test_AODE_version(clf):
     assert __version__ == clf.version()


-def test_AODE_nodes_leaves(clf, data):
+def test_AODE_nodes_edges(clf, data):
     assert clf.nodes_leaves() == (0, 0)
     clf.fit(*data)
-    assert clf.nodes_leaves() == (20, 0)
+    assert clf.nodes_leaves() == (20, 28)
+
+
+def test_AODE_states(clf, data):
+    assert clf.states_ == 0
+    clf = AODE(random_state=17)
+    clf.fit(*data)
+    assert clf.states_ == 23


 def test_AODE_classifier(data, clf):
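
The updated expectation `(20, 28)` is consistent with a 4-feature test set: AODE builds one SPODE per feature, each spanning all 4 features plus the class (5 nodes), with the class pointing at every feature (4 edges) and the super-parent pointing at the remaining 3 (3 edges). A back-of-the-envelope check (the 4-feature figure is inferred from the asserted numbers, not stated in the diff):

    n_features = 4                    # inferred: the TAN/KDB tests report 5 nodes = 4 features + class
    n_models = n_features             # one SPODE per super-parent feature

    nodes_per_model = n_features + 1  # all features plus the class variable
    edges_per_model = n_features + (n_features - 1)  # class -> each feature, super-parent -> the rest

    print(n_models * nodes_per_model, n_models * edges_per_model)  # 20 28
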
@@ -47,10 +47,17 @@ def test_KDB_version(clf):
     assert __version__ == clf.version()


-def test_KDB_nodes_leaves(clf, data):
+def test_KDB_nodes_edges(clf, data):
     assert clf.nodes_leaves() == (0, 0)
     clf.fit(*data)
-    assert clf.nodes_leaves() == (5, 0)
+    assert clf.nodes_leaves() == (5, 10)
+
+
+def test_KDB_states(clf, data):
+    assert clf.states_ == 0
+    clf = KDB(k=3, random_state=17)
+    clf.fit(*data)
+    assert clf.states_ == 23


 def test_KDB_classifier(data, clf):
@@ -45,11 +45,18 @@ def test_TAN_version(clf):
     assert __version__ == clf.version()


-def test_TAN_nodes_leaves(clf, data):
+def test_TAN_nodes_edges(clf, data):
     assert clf.nodes_leaves() == (0, 0)
     clf = TAN(random_state=17)
     clf.fit(*data, head="random")
-    assert clf.nodes_leaves() == (5, 0)
+    assert clf.nodes_leaves() == (5, 7)
+
+
+def test_TAN_states(clf, data):
+    assert clf.states_ == 0
+    clf = TAN(random_state=17)
+    clf.fit(*data)
+    assert clf.states_ == 23


 def test_TAN_random_head(data):