From 3812d271e5addfd69e2f26a81abfec080af2e2d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ricardo=20Monta=C3=B1ana?= Date: Thu, 15 Jun 2023 14:28:35 +0200 Subject: [PATCH] Add BoostAODE initial model --- bayesclass/__init__.py | 1 + bayesclass/clfs.py | 34 +++++++++++++++++++++++++++++++++- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/bayesclass/__init__.py b/bayesclass/__init__.py index 803f64a..948ba7d 100644 --- a/bayesclass/__init__.py +++ b/bayesclass/__init__.py @@ -18,4 +18,5 @@ __all__ = [ "AODE", "KDBNew", "AODENew", + "BoostAODE", ] diff --git a/bayesclass/clfs.py b/bayesclass/clfs.py index db40900..723c8d0 100644 --- a/bayesclass/clfs.py +++ b/bayesclass/clfs.py @@ -418,7 +418,13 @@ def build_spodes(features, class_name): class SPODE(BayesBase): def _check_params(self, X, y, kwargs): - expected_args = ["class_name", "features", "state_names"] + expected_args = [ + "class_name", + "features", + "state_names", + "sample_weight", + "weighted", + ] return self._check_params_fit(X, y, expected_args, kwargs) @@ -775,3 +781,29 @@ class Proposal(BaseEstimator): # np.array(self.state_names_[self.features_[i]]), # ) # raise ValueError("Discretization error") + + +class BoostAODE(AODE): + def fit(self, X, y, **kwargs): + self.n_features_in_ = X.shape[1] + self.feature_names_in_ = kwargs.get( + "features", default_feature_names(self.n_features_in_) + ) + self.class_name_ = kwargs.get("class_name", "class") + # build estimator + self._validate_estimator() + self.X_ = X + self.y_ = y + self.estimators_ = [] + self._train(kwargs) + # To keep compatiblity with the benchmark platform + self.fitted_ = True + self.nodes_leaves = self.nodes_edges + return self + + def _train(self, kwargs): + for dag in build_spodes(self.feature_names_in_, self.class_name_): + estimator = clone(self.estimator_) + estimator.dag_ = estimator.model_ = dag + estimator.fit(self.X_, self.y_, **kwargs) + self.estimators_.append(estimator)