diff --git a/bayesclass/clfs.py b/bayesclass/clfs.py index b00099a..25d1c9d 100644 --- a/bayesclass/clfs.py +++ b/bayesclass/clfs.py @@ -106,12 +106,12 @@ class BayesBase(BaseEstimator, ClassifierMixin): # Build the DAG self._build() # Train the model - self._train() + self._train(kwargs) self.fitted_ = True # Return the classifier return self - def _train(self): + def _train(self, kwargs): self.model_ = BayesianNetwork( self.dag_.edges(), show_progress=self.show_progress ) @@ -119,6 +119,7 @@ class BayesBase(BaseEstimator, ClassifierMixin): self.dataset_, estimator=BayesianEstimator, prior_type="K2", + state_names=kwargs["state_names"], ) def predict(self, X): @@ -227,7 +228,7 @@ class TAN(BayesBase): def _check_params(self, X, y, kwargs): self.head_ = 0 - expected_args = ["class_name", "features", "head"] + expected_args = ["class_name", "features", "head", "state_names"] X, y = self._check_params_fit(X, y, expected_args, kwargs) if self.head_ == "random": self.head_ = random.randint(0, len(self.features_) - 1) @@ -253,7 +254,7 @@ class KDB(BayesBase): ) def _check_params(self, X, y, kwargs): - expected_args = ["class_name", "features"] + expected_args = ["class_name", "features", "state_names"] return self._check_params_fit(X, y, expected_args, kwargs) def _build(self): @@ -327,7 +328,7 @@ class AODE(BayesBase, BaseEnsemble): ) def _check_params(self, X, y, kwargs): - expected_args = ["class_name", "features"] + expected_args = ["class_name", "features", "state_names"] return self._check_params_fit(X, y, expected_args, kwargs) def _build(self):