diff --git a/bayesclass/_version.py b/bayesclass/_version.py index 3dc1f76..485f44a 100644 --- a/bayesclass/_version.py +++ b/bayesclass/_version.py @@ -1 +1 @@ -__version__ = "0.1.0" +__version__ = "0.1.1" diff --git a/bayesclass/clfs.py b/bayesclass/clfs.py index c47adca..e68d435 100644 --- a/bayesclass/clfs.py +++ b/bayesclass/clfs.py @@ -34,7 +34,8 @@ class BayesBase(BaseEstimator, ClassifierMixin): def nodes_leaves(self): """To keep compatiblity with the benchmark platform""" - return 0, 0 + nodes = len(self.dag_) if hasattr(self, "dag_") else 0 + return nodes, 0 def _check_params_fit(self, X, y, expected_args, kwargs): """Check the common parameters passed to fit""" @@ -57,6 +58,7 @@ class BayesBase(BaseEstimator, ClassifierMixin): raise ValueError( "Number of features does not match the number of columns in X" ) + self.n_features_in_ = X.shape[1] return X, y def fit(self, X, y, **kwargs): @@ -115,11 +117,12 @@ class BayesBase(BaseEstimator, ClassifierMixin): self.model_ = BayesianNetwork( self.dag_.edges(), show_progress=self.show_progress ) + states = dict(state_names=kwargs.pop("state_names", [])) self.model_.fit( self.dataset_, estimator=BayesianEstimator, prior_type="K2", - state_names=kwargs["state_names"], + **states, ) def predict(self, X): @@ -267,7 +270,7 @@ class KDB(BayesBase): 5.1. Select feature Xmax which is not in S and has the largest value I(Xmax;C). 5.2. Add a node to BN representing Xmax. 5.3. Add an arc from C to Xmax in BN. - 5.4. Add m =min(lSl,/c) arcs from m distinct features Xj in S with the highest value for I(Xmax;X,jC). + 5.4. Add m = min(lSl,/c) arcs from m distinct features Xj in S with the highest value for I(Xmax;X,jC). 5.5. Add Xmax to S. Compute the conditional probabilility infered by the structure of BN by using counts from DB, and output BN. """ @@ -331,8 +334,14 @@ class AODE(BayesBase, BaseEnsemble): expected_args = ["class_name", "features", "state_names"] return self._check_params_fit(X, y, expected_args, kwargs) - def _build(self): + def nodes_leaves(self): + """To keep compatiblity with the benchmark platform""" + nodes = 0 + if hasattr(self, "fitted_"): + nodes = sum([len(x) for x in self.models_]) + return nodes, 0 + def _build(self): self.dag_ = None def _train(self, kwargs): @@ -349,11 +358,12 @@ class AODE(BayesBase, BaseEnsemble): model = BayesianNetwork( feature_edges, show_progress=self.show_progress ) + states = dict(state_names=kwargs.pop("state_names", [])) model.fit( self.dataset_, estimator=BayesianEstimator, prior_type="K2", - state_names=kwargs["state_names"], + **states, ) self.models_.append(model) diff --git a/bayesclass/tests/test_AODE.py b/bayesclass/tests/test_AODE.py index 8c9a347..f8562f1 100644 --- a/bayesclass/tests/test_AODE.py +++ b/bayesclass/tests/test_AODE.py @@ -55,8 +55,10 @@ def test_AODE_version(clf): assert __version__ == clf.version() -def test_AODE_nodes_leaves(clf): +def test_AODE_nodes_leaves(clf, data): assert clf.nodes_leaves() == (0, 0) + clf.fit(*data) + assert clf.nodes_leaves() == (20, 0) def test_AODE_classifier(data, clf): diff --git a/bayesclass/tests/test_KDB.py b/bayesclass/tests/test_KDB.py index c304747..d776519 100644 --- a/bayesclass/tests/test_KDB.py +++ b/bayesclass/tests/test_KDB.py @@ -46,8 +46,10 @@ def test_KDB_version(clf): assert __version__ == clf.version() -def test_KDB_nodes_leaves(clf): +def test_KDB_nodes_leaves(clf, data): assert clf.nodes_leaves() == (0, 0) + clf.fit(*data) + assert clf.nodes_leaves() == (5, 0) def test_KDB_classifier(data, clf): diff --git a/bayesclass/tests/test_TAN.py b/bayesclass/tests/test_TAN.py index 355ccd8..5bb8283 100644 --- a/bayesclass/tests/test_TAN.py +++ b/bayesclass/tests/test_TAN.py @@ -45,8 +45,11 @@ def test_TAN_version(clf): assert __version__ == clf.version() -def test_TAN_nodes_leaves(clf): +def test_TAN_nodes_leaves(clf, data): assert clf.nodes_leaves() == (0, 0) + clf = TAN(random_state=17) + clf.fit(*data, head="random") + assert clf.nodes_leaves() == (5, 0) def test_TAN_random_head(data): diff --git a/bayesclass/tests/test_common.py b/bayesclass/tests/test_common.py index 917ad76..b5334d1 100644 --- a/bayesclass/tests/test_common.py +++ b/bayesclass/tests/test_common.py @@ -5,10 +5,12 @@ from sklearn.utils.estimator_checks import check_estimator from bayesclass.clfs import TAN, KDB, AODE -@pytest.mark.parametrize("estimator", [TAN(), KDB(k=2), AODE()]) -# @pytest.mark.parametrize("estimator", [AODE()]) -def test_all_estimators(estimator): +@pytest.mark.parametrize("estimators", [TAN(), KDB(k=2), AODE()]) +# @pytest.mark.parametrize("estimators", [TAN()]) + + +def test_all_estimators(estimators): i = 0 - for estimator, test in check_estimator(estimator, generate_only=True): + for estimator, test in check_estimator(estimators, generate_only=True): print(i := i + 1, test) # test(estimator)