mirror of
https://github.com/Doctorado-ML/bayesclass.git
synced 2025-08-21 10:35:54 +00:00
Add fit_params to model fit
This commit is contained in:
@@ -106,12 +106,12 @@ class BayesBase(BaseEstimator, ClassifierMixin):
|
||||
# Build the DAG
|
||||
self._build()
|
||||
# Train the model
|
||||
self._train()
|
||||
self._train(kwargs)
|
||||
self.fitted_ = True
|
||||
# Return the classifier
|
||||
return self
|
||||
|
||||
def _train(self):
|
||||
def _train(self, kwargs):
|
||||
self.model_ = BayesianNetwork(
|
||||
self.dag_.edges(), show_progress=self.show_progress
|
||||
)
|
||||
@@ -119,6 +119,7 @@ class BayesBase(BaseEstimator, ClassifierMixin):
|
||||
self.dataset_,
|
||||
estimator=BayesianEstimator,
|
||||
prior_type="K2",
|
||||
state_names=kwargs["state_names"],
|
||||
)
|
||||
|
||||
def predict(self, X):
|
||||
@@ -227,7 +228,7 @@ class TAN(BayesBase):
|
||||
|
||||
def _check_params(self, X, y, kwargs):
|
||||
self.head_ = 0
|
||||
expected_args = ["class_name", "features", "head"]
|
||||
expected_args = ["class_name", "features", "head", "state_names"]
|
||||
X, y = self._check_params_fit(X, y, expected_args, kwargs)
|
||||
if self.head_ == "random":
|
||||
self.head_ = random.randint(0, len(self.features_) - 1)
|
||||
@@ -253,7 +254,7 @@ class KDB(BayesBase):
|
||||
)
|
||||
|
||||
def _check_params(self, X, y, kwargs):
|
||||
expected_args = ["class_name", "features"]
|
||||
expected_args = ["class_name", "features", "state_names"]
|
||||
return self._check_params_fit(X, y, expected_args, kwargs)
|
||||
|
||||
def _build(self):
|
||||
@@ -327,7 +328,7 @@ class AODE(BayesBase, BaseEnsemble):
|
||||
)
|
||||
|
||||
def _check_params(self, X, y, kwargs):
|
||||
expected_args = ["class_name", "features"]
|
||||
expected_args = ["class_name", "features", "state_names"]
|
||||
return self._check_params_fit(X, y, expected_args, kwargs)
|
||||
|
||||
def _build(self):
|
||||
|
Reference in New Issue
Block a user