Add fit_params to model fit

This commit is contained in:
2022-12-28 19:15:34 +01:00
parent 3b4fc10a3e
commit e7300366ca

View File

@@ -106,12 +106,12 @@ class BayesBase(BaseEstimator, ClassifierMixin):
# Build the DAG
self._build()
# Train the model
self._train()
self._train(kwargs)
self.fitted_ = True
# Return the classifier
return self
def _train(self):
def _train(self, kwargs):
self.model_ = BayesianNetwork(
self.dag_.edges(), show_progress=self.show_progress
)
@@ -119,6 +119,7 @@ class BayesBase(BaseEstimator, ClassifierMixin):
self.dataset_,
estimator=BayesianEstimator,
prior_type="K2",
state_names=kwargs["state_names"],
)
def predict(self, X):
@@ -227,7 +228,7 @@ class TAN(BayesBase):
def _check_params(self, X, y, kwargs):
self.head_ = 0
expected_args = ["class_name", "features", "head"]
expected_args = ["class_name", "features", "head", "state_names"]
X, y = self._check_params_fit(X, y, expected_args, kwargs)
if self.head_ == "random":
self.head_ = random.randint(0, len(self.features_) - 1)
@@ -253,7 +254,7 @@ class KDB(BayesBase):
)
def _check_params(self, X, y, kwargs):
expected_args = ["class_name", "features"]
expected_args = ["class_name", "features", "state_names"]
return self._check_params_fit(X, y, expected_args, kwargs)
def _build(self):
@@ -327,7 +328,7 @@ class AODE(BayesBase, BaseEnsemble):
)
def _check_params(self, X, y, kwargs):
expected_args = ["class_name", "features"]
expected_args = ["class_name", "features", "state_names"]
return self._check_params_fit(X, y, expected_args, kwargs)
def _build(self):