From 4d416959adf7aa0b6f38f75d642babc362f2c3b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana=20G=C3=B3mez?= Date: Sun, 22 Jan 2023 14:15:19 +0100 Subject: [PATCH] fix: :bug: Fix depth_ property as an alias of states_ --- bayesclass/clfs.py | 5 ++++- bayesclass/tests/test_AODE.py | 1 + bayesclass/tests/test_KDB.py | 1 + bayesclass/tests/test_TAN.py | 1 + 4 files changed, 7 insertions(+), 1 deletion(-) diff --git a/bayesclass/clfs.py b/bayesclass/clfs.py index 7f94fa8..7973691 100644 --- a/bayesclass/clfs.py +++ b/bayesclass/clfs.py @@ -21,7 +21,6 @@ class BayesBase(BaseEstimator, ClassifierMixin): self.show_progress = show_progress # To keep compatiblity with the benchmark platform self.nodes_leaves = self.nodes_edges - self.depth_ = self.states_ def _more_tags(self): return { @@ -71,6 +70,10 @@ class BayesBase(BaseEstimator, ClassifierMixin): return sum([len(item) for _, item in self.model_.states.items()]) return 0 + @property + def depth_(self): + return self.states_ + def fit(self, X, y, **kwargs): """A reference implementation of a fitting function for a classifier. diff --git a/bayesclass/tests/test_AODE.py b/bayesclass/tests/test_AODE.py index 17eb683..ec3c908 100644 --- a/bayesclass/tests/test_AODE.py +++ b/bayesclass/tests/test_AODE.py @@ -66,6 +66,7 @@ def test_AODE_states(clf, data): clf = AODE(random_state=17) clf.fit(*data) assert clf.states_ == 23 + assert clf.depth_ == clf.states_ def test_AODE_classifier(data, clf): diff --git a/bayesclass/tests/test_KDB.py b/bayesclass/tests/test_KDB.py index 17d5729..2d40d14 100644 --- a/bayesclass/tests/test_KDB.py +++ b/bayesclass/tests/test_KDB.py @@ -58,6 +58,7 @@ def test_KDB_states(clf, data): clf = KDB(k=3, random_state=17) clf.fit(*data) assert clf.states_ == 23 + assert clf.depth_ == clf.states_ def test_KDB_classifier(data, clf): diff --git a/bayesclass/tests/test_TAN.py b/bayesclass/tests/test_TAN.py index ca4f8bf..d594bc8 100644 --- a/bayesclass/tests/test_TAN.py +++ b/bayesclass/tests/test_TAN.py @@ -57,6 +57,7 @@ def test_TAN_states(clf, data): clf = TAN(random_state=17) clf.fit(*data) assert clf.states_ == 23 + assert clf.depth_ == clf.states_ def test_TAN_random_head(data):