diff --git a/bayesclass/clfs.py b/bayesclass/clfs.py
index c28066d..6dc7a32 100644
--- a/bayesclass/clfs.py
+++ b/bayesclass/clfs.py
@@ -39,6 +39,14 @@ class BayesBase(BaseEstimator, ClassifierMixin):
             return len(self.dag_), len(self.dag_.edges())
         return 0, 0
 
+    @staticmethod
+    def default_feature_names(num_features):
+        return [f"feature_{i}" for i in range(num_features)]
+
+    @staticmethod
+    def default_class_name():
+        return "class"
+
     def _check_params_fit(self, X, y, expected_args, kwargs):
         """Check the common parameters passed to fit"""
         # Check that X and y have correct shape
@@ -48,8 +56,8 @@ class BayesBase(BaseEstimator, ClassifierMixin):
         self.classes_ = unique_labels(y)
         self.n_classes_ = self.classes_.shape[0]
         # Default values
-        self.class_name_ = "class"
-        self.features_ = [f"feature_{i}" for i in range(X.shape[1])]
+        self.class_name_ = self.default_class_name()
+        self.features_ = self.default_feature_names(X.shape[1])
         for key, value in kwargs.items():
             if key in expected_args:
                 setattr(self, f"{key}_", value)
@@ -458,7 +466,11 @@ class KDBNew(KDB):
     def fit(self, X, y, **kwargs):
         self.discretizer_ = FImdlp(n_jobs=1)
         Xd = self.discretizer_.fit_transform(X, y)
-        features = kwargs["features"]
+        features = (
+            kwargs["features"]
+            if "features" in kwargs
+            else self.default_feature_names(Xd.shape[1])
+        )
         self.compute_kwargs(Xd, y, kwargs)
         # Build the model
         super().fit(Xd, y, **kwargs)
@@ -475,7 +487,12 @@ class KDBNew(KDB):
             features[i]: np.unique(Xd[:, i]).tolist()
             for i in range(Xd.shape[1])
         }
-        states[kwargs["class_name"]] = np.unique(y).tolist()
+        class_name = (
+            kwargs["class_name"]
+            if "class_name" in kwargs
+            else self.default_class_name()
+        )
+        states[class_name] = np.unique(y).tolist()
         kwargs["state_names"] = states
         self.kwargs_ = kwargs
 