diff --git a/bayesclass/clfs.py b/bayesclass/clfs.py index 5c0edd4..7f94fa8 100644 --- a/bayesclass/clfs.py +++ b/bayesclass/clfs.py @@ -1,4 +1,5 @@ import random +import warnings import numpy as np import pandas as pd from scipy.stats import mode @@ -18,6 +19,9 @@ class BayesBase(BaseEstimator, ClassifierMixin): def __init__(self, random_state, show_progress): self.random_state = random_state self.show_progress = show_progress + # To keep compatibility with the benchmark platform + self.nodes_leaves = self.nodes_edges + self.depth_ = self.states_ def _more_tags(self): return { @@ -32,10 +36,10 @@ class BayesBase(BaseEstimator, ClassifierMixin): """Return the version of the package.""" return __version__ - def nodes_leaves(self): - """To keep compatiblity with the benchmark platform""" - nodes = len(self.dag_) if hasattr(self, "dag_") else 0 - return nodes, 0 + def nodes_edges(self): + if hasattr(self, "dag_"): + return len(self.dag_), len(self.dag_.edges()) + return 0, 0 def _check_params_fit(self, X, y, expected_args, kwargs): """Check the common parameters passed to fit""" @@ -61,6 +65,12 @@ class BayesBase(BaseEstimator, ClassifierMixin): self.n_features_in_ = X.shape[1] return X, y + @property + def states_(self): + if hasattr(self, "fitted_"): + return sum([len(item) for _, item in self.model_.states.items()]) + return 0 + def fit(self, X, y, **kwargs): """A reference implementation of a fitting function for a classifier. 
@@ -180,6 +190,7 @@ class BayesBase(BaseEstimator, ClassifierMixin): return self.model_.predict(dataset).values.ravel() def plot(self, title="", node_size=800): + warnings.simplefilter("ignore", UserWarning) nx.draw_circular( self.model_, with_labels=True, @@ -240,12 +251,37 @@ class TAN(BayesBase): return X, y def _build(self): - est = TreeSearch(self.dataset_, root_node=self.features_[self.head_]) - self.dag_ = est.estimate( - estimator_type="tan", - class_node=self.class_name_, - show_progress=self.show_progress, + # est = TreeSearch(self.dataset_, root_node=self.features_[self.head_]) + # self.dag_ = est.estimate( + # estimator_type="tan", + # class_node=self.class_name_, + # show_progress=self.show_progress, + # ) + # Code taken from pgmpy + n_jobs = -1 + weights = TreeSearch._get_conditional_weights( + self.dataset_, + self.class_name_, + "mutual_info", + n_jobs, + self.show_progress, ) + # Step 4.2: Construct chow-liu DAG on {data.columns - class_node} + class_node_idx = np.where(self.dataset_.columns == self.class_name_)[ + 0 + ][0] + weights = np.delete(weights, class_node_idx, axis=0) + weights = np.delete(weights, class_node_idx, axis=1) + reduced_columns = np.delete(self.dataset_.columns, class_node_idx) + D = TreeSearch._create_tree_and_dag( + weights, reduced_columns, self.features_[self.head_] + ) + + # Step 4.3: Add edges from class_node to all other nodes. 
+ D.add_edges_from( + [(self.class_name_, node) for node in reduced_columns] + ) + self.dag_ = D class KDB(BayesBase): @@ -331,12 +367,25 @@ class AODE(BayesBase, BaseEnsemble): expected_args = ["class_name", "features", "state_names"] return self._check_params_fit(X, y, expected_args, kwargs) - def nodes_leaves(self): - """To keep compatiblity with the benchmark platform""" + def nodes_edges(self): nodes = 0 + edges = 0 if hasattr(self, "fitted_"): nodes = sum([len(x) for x in self.models_]) - return nodes, 0 + edges = sum([len(x.edges()) for x in self.models_]) + return nodes, edges + + @property + def states_(self): + if hasattr(self, "fitted_"): + return sum( + [ + len(item) + for model in self.models_ + for _, item in model.states.items() + ] + ) / len(self.models_) + return 0 def _build(self): self.dag_ = None @@ -365,6 +414,7 @@ class AODE(BayesBase, BaseEnsemble): self.models_.append(model) def plot(self, title=""): + warnings.simplefilter("ignore", UserWarning) for idx, model in enumerate(self.models_): self.model_ = model super().plot(title=f"{idx} {title}") diff --git a/bayesclass/tests/test_AODE.py b/bayesclass/tests/test_AODE.py index f8562f1..17eb683 100644 --- a/bayesclass/tests/test_AODE.py +++ b/bayesclass/tests/test_AODE.py @@ -55,10 +55,17 @@ def test_AODE_version(clf): assert __version__ == clf.version() -def test_AODE_nodes_leaves(clf, data): +def test_AODE_nodes_edges(clf, data): assert clf.nodes_leaves() == (0, 0) clf.fit(*data) - assert clf.nodes_leaves() == (20, 0) + assert clf.nodes_leaves() == (20, 28) + + +def test_AODE_states(clf, data): + assert clf.states_ == 0 + clf = AODE(random_state=17) + clf.fit(*data) + assert clf.states_ == 23 def test_AODE_classifier(data, clf): diff --git a/bayesclass/tests/test_KDB.py b/bayesclass/tests/test_KDB.py index 10f93d4..17d5729 100644 --- a/bayesclass/tests/test_KDB.py +++ b/bayesclass/tests/test_KDB.py @@ -47,10 +47,17 @@ def test_KDB_version(clf): assert __version__ == clf.version() -def 
test_KDB_nodes_leaves(clf, data): +def test_KDB_nodes_edges(clf, data): assert clf.nodes_leaves() == (0, 0) clf.fit(*data) - assert clf.nodes_leaves() == (5, 0) + assert clf.nodes_leaves() == (5, 10) + + +def test_KDB_states(clf, data): + assert clf.states_ == 0 + clf = KDB(k=3, random_state=17) + clf.fit(*data) + assert clf.states_ == 23 def test_KDB_classifier(data, clf): diff --git a/bayesclass/tests/test_TAN.py b/bayesclass/tests/test_TAN.py index 5bb8283..ca4f8bf 100644 --- a/bayesclass/tests/test_TAN.py +++ b/bayesclass/tests/test_TAN.py @@ -45,11 +45,18 @@ def test_TAN_version(clf): assert __version__ == clf.version() -def test_TAN_nodes_leaves(clf, data): +def test_TAN_nodes_edges(clf, data): assert clf.nodes_leaves() == (0, 0) clf = TAN(random_state=17) clf.fit(*data, head="random") - assert clf.nodes_leaves() == (5, 0) + assert clf.nodes_leaves() == (5, 7) + + +def test_TAN_states(clf, data): + assert clf.states_ == 0 + clf = TAN(random_state=17) + clf.fit(*data) + assert clf.states_ == 23 def test_TAN_random_head(data):