diff --git a/bayesclass/clfs.py b/bayesclass/clfs.py index 7f94fa8..7973691 100644 --- a/bayesclass/clfs.py +++ b/bayesclass/clfs.py @@ -21,7 +21,6 @@ class BayesBase(BaseEstimator, ClassifierMixin): self.show_progress = show_progress # To keep compatiblity with the benchmark platform self.nodes_leaves = self.nodes_edges - self.depth_ = self.states_ def _more_tags(self): return { @@ -71,6 +70,10 @@ class BayesBase(BaseEstimator, ClassifierMixin): return sum([len(item) for _, item in self.model_.states.items()]) return 0 + @property + def depth_(self): + return self.states_ + def fit(self, X, y, **kwargs): """A reference implementation of a fitting function for a classifier. diff --git a/bayesclass/tests/test_AODE.py b/bayesclass/tests/test_AODE.py index 17eb683..ec3c908 100644 --- a/bayesclass/tests/test_AODE.py +++ b/bayesclass/tests/test_AODE.py @@ -66,6 +66,7 @@ def test_AODE_states(clf, data): clf = AODE(random_state=17) clf.fit(*data) assert clf.states_ == 23 + assert clf.depth_ == clf.states_ def test_AODE_classifier(data, clf): diff --git a/bayesclass/tests/test_KDB.py b/bayesclass/tests/test_KDB.py index 17d5729..2d40d14 100644 --- a/bayesclass/tests/test_KDB.py +++ b/bayesclass/tests/test_KDB.py @@ -58,6 +58,7 @@ def test_KDB_states(clf, data): clf = KDB(k=3, random_state=17) clf.fit(*data) assert clf.states_ == 23 + assert clf.depth_ == clf.states_ def test_KDB_classifier(data, clf): diff --git a/bayesclass/tests/test_TAN.py b/bayesclass/tests/test_TAN.py index ca4f8bf..d594bc8 100644 --- a/bayesclass/tests/test_TAN.py +++ b/bayesclass/tests/test_TAN.py @@ -57,6 +57,7 @@ def test_TAN_states(clf, data): clf = TAN(random_state=17) clf.fit(*data) assert clf.states_ == 23 + assert clf.depth_ == clf.states_ def test_TAN_random_head(data):