default_features_class_name

This commit is contained in:
2023-02-05 20:18:44 +01:00
parent 2a6547c71d
commit 3e049ac89d

View File

@@ -39,6 +39,14 @@ class BayesBase(BaseEstimator, ClassifierMixin):
return len(self.dag_), len(self.dag_.edges())
return 0, 0
@staticmethod
def default_feature_names(num_features):
return [f"feature_{i}" for i in range(num_features)]
@staticmethod
def default_class_name():
return "class"
def _check_params_fit(self, X, y, expected_args, kwargs):
"""Check the common parameters passed to fit"""
# Check that X and y have correct shape
@@ -48,8 +56,8 @@ class BayesBase(BaseEstimator, ClassifierMixin):
self.classes_ = unique_labels(y)
self.n_classes_ = self.classes_.shape[0]
# Default values
self.class_name_ = "class"
self.features_ = [f"feature_{i}" for i in range(X.shape[1])]
self.class_name_ = self.default_class_name()
self.features_ = self.default_feature_names(X.shape[1])
for key, value in kwargs.items():
if key in expected_args:
setattr(self, f"{key}_", value)
@@ -458,7 +466,11 @@ class KDBNew(KDB):
def fit(self, X, y, **kwargs):
self.discretizer_ = FImdlp(n_jobs=1)
Xd = self.discretizer_.fit_transform(X, y)
features = kwargs["features"]
features = (
kwargs["features"]
if "features" in kwargs
else self.default_feature_names(Xd.shape[1])
)
self.compute_kwargs(Xd, y, kwargs)
# Build the model
super().fit(Xd, y, **kwargs)
@@ -475,7 +487,12 @@ class KDBNew(KDB):
features[i]: np.unique(Xd[:, i]).tolist()
for i in range(Xd.shape[1])
}
states[kwargs["class_name"]] = np.unique(y).tolist()
class_name = (
kwargs["class_name"]
if "class_name" in kwargs
else self.default_class_name()
)
states[class_name] = np.unique(y).tolist()
kwargs["state_names"] = states
self.kwargs_ = kwargs