mirror of
https://github.com/Doctorado-ML/bayesclass.git
synced 2025-08-21 10:35:54 +00:00
Add fit_params to model fit
This commit is contained in:
@@ -106,12 +106,12 @@ class BayesBase(BaseEstimator, ClassifierMixin):
|
|||||||
# Build the DAG
|
# Build the DAG
|
||||||
self._build()
|
self._build()
|
||||||
# Train the model
|
# Train the model
|
||||||
self._train()
|
self._train(kwargs)
|
||||||
self.fitted_ = True
|
self.fitted_ = True
|
||||||
# Return the classifier
|
# Return the classifier
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def _train(self):
|
def _train(self, kwargs):
|
||||||
self.model_ = BayesianNetwork(
|
self.model_ = BayesianNetwork(
|
||||||
self.dag_.edges(), show_progress=self.show_progress
|
self.dag_.edges(), show_progress=self.show_progress
|
||||||
)
|
)
|
||||||
@@ -119,6 +119,7 @@ class BayesBase(BaseEstimator, ClassifierMixin):
|
|||||||
self.dataset_,
|
self.dataset_,
|
||||||
estimator=BayesianEstimator,
|
estimator=BayesianEstimator,
|
||||||
prior_type="K2",
|
prior_type="K2",
|
||||||
|
state_names=kwargs["state_names"],
|
||||||
)
|
)
|
||||||
|
|
||||||
def predict(self, X):
|
def predict(self, X):
|
||||||
@@ -227,7 +228,7 @@ class TAN(BayesBase):
|
|||||||
|
|
||||||
def _check_params(self, X, y, kwargs):
|
def _check_params(self, X, y, kwargs):
|
||||||
self.head_ = 0
|
self.head_ = 0
|
||||||
expected_args = ["class_name", "features", "head"]
|
expected_args = ["class_name", "features", "head", "state_names"]
|
||||||
X, y = self._check_params_fit(X, y, expected_args, kwargs)
|
X, y = self._check_params_fit(X, y, expected_args, kwargs)
|
||||||
if self.head_ == "random":
|
if self.head_ == "random":
|
||||||
self.head_ = random.randint(0, len(self.features_) - 1)
|
self.head_ = random.randint(0, len(self.features_) - 1)
|
||||||
@@ -253,7 +254,7 @@ class KDB(BayesBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _check_params(self, X, y, kwargs):
|
def _check_params(self, X, y, kwargs):
|
||||||
expected_args = ["class_name", "features"]
|
expected_args = ["class_name", "features", "state_names"]
|
||||||
return self._check_params_fit(X, y, expected_args, kwargs)
|
return self._check_params_fit(X, y, expected_args, kwargs)
|
||||||
|
|
||||||
def _build(self):
|
def _build(self):
|
||||||
@@ -327,7 +328,7 @@ class AODE(BayesBase, BaseEnsemble):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _check_params(self, X, y, kwargs):
|
def _check_params(self, X, y, kwargs):
|
||||||
expected_args = ["class_name", "features"]
|
expected_args = ["class_name", "features", "state_names"]
|
||||||
return self._check_params_fit(X, y, expected_args, kwargs)
|
return self._check_params_fit(X, y, expected_args, kwargs)
|
||||||
|
|
||||||
def _build(self):
|
def _build(self):
|
||||||
|
Reference in New Issue
Block a user